Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mssql_database - support the threat_detection_policy property #6437

Merged
merged 6 commits into from
Apr 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions azurerm/internal/services/mssql/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

type Client struct {
DatabasesClient *sql.DatabasesClient
DatabaseThreatDetectionPoliciesClient *sql.DatabaseThreatDetectionPoliciesClient
ElasticPoolsClient *sql.ElasticPoolsClient
DatabaseVulnerabilityAssessmentRuleBaselinesClient *sql.DatabaseVulnerabilityAssessmentRuleBaselinesClient
ServersClient *sql.ServersClient
Expand All @@ -17,34 +18,38 @@ type Client struct {
}

func NewClient(o *common.ClientOptions) *Client {
DatabasesClient := sql.NewDatabasesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&DatabasesClient.Client, o.ResourceManagerAuthorizer)
databasesClient := sql.NewDatabasesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&databasesClient.Client, o.ResourceManagerAuthorizer)

ElasticPoolsClient := sql.NewElasticPoolsClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&ElasticPoolsClient.Client, o.ResourceManagerAuthorizer)
databaseThreatDetectionPoliciesClient := sql.NewDatabaseThreatDetectionPoliciesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&databaseThreatDetectionPoliciesClient.Client, o.ResourceManagerAuthorizer)

DatabaseVulnerabilityAssessmentRuleBaselinesClient := sql.NewDatabaseVulnerabilityAssessmentRuleBaselinesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&DatabaseVulnerabilityAssessmentRuleBaselinesClient.Client, o.ResourceManagerAuthorizer)
databaseVulnerabilityAssessmentRuleBaselinesClient := sql.NewDatabaseVulnerabilityAssessmentRuleBaselinesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&databaseVulnerabilityAssessmentRuleBaselinesClient.Client, o.ResourceManagerAuthorizer)

ServerSecurityAlertPoliciesClient := sql.NewServerSecurityAlertPoliciesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&ServerSecurityAlertPoliciesClient.Client, o.ResourceManagerAuthorizer)
elasticPoolsClient := sql.NewElasticPoolsClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&elasticPoolsClient.Client, o.ResourceManagerAuthorizer)

ServersClient := sql.NewServersClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&ServersClient.Client, o.ResourceManagerAuthorizer)
serverSecurityAlertPoliciesClient := sql.NewServerSecurityAlertPoliciesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&serverSecurityAlertPoliciesClient.Client, o.ResourceManagerAuthorizer)

ServerVulnerabilityAssessmentsClient := sql.NewServerVulnerabilityAssessmentsClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&ServerVulnerabilityAssessmentsClient.Client, o.ResourceManagerAuthorizer)
serversClient := sql.NewServersClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&serversClient.Client, o.ResourceManagerAuthorizer)

SQLVirtualMachinesClient := sqlvirtualmachine.NewSQLVirtualMachinesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&SQLVirtualMachinesClient.Client, o.ResourceManagerAuthorizer)
serverVulnerabilityAssessmentsClient := sql.NewServerVulnerabilityAssessmentsClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&serverVulnerabilityAssessmentsClient.Client, o.ResourceManagerAuthorizer)

sqlVirtualMachinesClient := sqlvirtualmachine.NewSQLVirtualMachinesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&sqlVirtualMachinesClient.Client, o.ResourceManagerAuthorizer)

return &Client{
DatabasesClient: &DatabasesClient,
ElasticPoolsClient: &ElasticPoolsClient,
DatabaseVulnerabilityAssessmentRuleBaselinesClient: &DatabaseVulnerabilityAssessmentRuleBaselinesClient,
ServersClient: &ServersClient,
ServerSecurityAlertPoliciesClient: &ServerSecurityAlertPoliciesClient,
ServerVulnerabilityAssessmentsClient: &ServerVulnerabilityAssessmentsClient,
VirtualMachinesClient: &SQLVirtualMachinesClient,
DatabasesClient: &databasesClient,
DatabaseThreatDetectionPoliciesClient: &databaseThreatDetectionPoliciesClient,
DatabaseVulnerabilityAssessmentRuleBaselinesClient: &databaseVulnerabilityAssessmentRuleBaselinesClient,
ElasticPoolsClient: &elasticPoolsClient,
ServersClient: &serversClient,
ServerSecurityAlertPoliciesClient: &serverSecurityAlertPoliciesClient,
ServerVulnerabilityAssessmentsClient: &serverVulnerabilityAssessmentsClient,
VirtualMachinesClient: &sqlVirtualMachinesClient,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func dataSourceArmMsSqlDatabaseRead(d *schema.ResourceData, meta interface{}) er
return fmt.Errorf("Database %q (Resource Group %q, SQL Server %q) was not found", name, serverId.ResourceGroup, serverId.Name)
}

return fmt.Errorf("Failure in making Read request on AzureRM Database %s (Resource Group %q, SQL Server %q): %+v", name, serverId.ResourceGroup, serverId.Name, err)
return fmt.Errorf("making Read request on AzureRM Database %s (Resource Group %q, SQL Server %q): %+v", name, serverId.ResourceGroup, serverId.Name, err)
}

if id := resp.ID; id != nil {
Expand Down
224 changes: 213 additions & 11 deletions azurerm/internal/services/mssql/resource_arm_mssql_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mssql
import (
"fmt"
"log"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/services/preview/sql/mgmt/v3.0/sql"
Expand Down Expand Up @@ -176,6 +177,92 @@ func resourceArmMsSqlDatabase() *schema.Resource {
Computed: true,
},

"threat_detection_policy": {
Type: schema.TypeList,
Optional: true,
Computed: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"disabled_alerts": {
Type: schema.TypeSet,
Optional: true,
Set: schema.HashString,
Elem: &schema.Schema{
Type: schema.TypeString,
ValidateFunc: validation.StringInSlice([]string{
"Sql_Injection",
"Sql_Injection_Vulnerability",
"Access_Anomaly",
}, true),
},
},

"email_account_admins": {
Type: schema.TypeString,
Optional: true,
DiffSuppressFunc: suppress.CaseDifference,
Default: string(sql.SecurityAlertPolicyEmailAccountAdminsDisabled),
ValidateFunc: validation.StringInSlice([]string{
string(sql.SecurityAlertPolicyEmailAccountAdminsDisabled),
string(sql.SecurityAlertPolicyEmailAccountAdminsEnabled),
}, true),
},

"email_addresses": {
Type: schema.TypeSet,
Optional: true,
Elem: &schema.Schema{
Type: schema.TypeString,
},
Set: schema.HashString,
},

"retention_days": {
Type: schema.TypeInt,
Optional: true,
ValidateFunc: validation.IntAtLeast(0),
},

"state": {
Type: schema.TypeString,
Optional: true,
DiffSuppressFunc: suppress.CaseDifference,
Default: string(sql.SecurityAlertPolicyStateDisabled),
ValidateFunc: validation.StringInSlice([]string{
string(sql.SecurityAlertPolicyStateDisabled),
string(sql.SecurityAlertPolicyStateEnabled),
string(sql.SecurityAlertPolicyStateNew),
}, true),
},

"storage_account_access_key": {
Type: schema.TypeString,
Optional: true,
Sensitive: true,
ValidateFunc: validation.StringIsNotEmpty,
},

"storage_endpoint": {
Type: schema.TypeString,
Optional: true,
ValidateFunc: validation.StringIsNotEmpty,
},

"use_server_default": {
Type: schema.TypeString,
Optional: true,
DiffSuppressFunc: suppress.CaseDifference,
Default: string(sql.SecurityAlertPolicyUseServerDefaultDisabled),
ValidateFunc: validation.StringInSlice([]string{
string(sql.SecurityAlertPolicyUseServerDefaultDisabled),
string(sql.SecurityAlertPolicyUseServerDefaultEnabled),
}, true),
},
},
},
},

"tags": tags.Schema(),
},
}
Expand All @@ -184,6 +271,7 @@ func resourceArmMsSqlDatabase() *schema.Resource {
func resourceArmMsSqlDatabaseCreateUpdate(d *schema.ResourceData, meta interface{}) error {
client := meta.(*clients.Client).MSSQL.DatabasesClient
serverClient := meta.(*clients.Client).MSSQL.ServersClient
threatClient := meta.(*clients.Client).MSSQL.DatabaseThreatDetectionPoliciesClient
ctx, cancel := timeouts.ForCreateUpdate(meta.(*clients.Client).StopContext, d)
defer cancel()

Expand All @@ -208,7 +296,7 @@ func resourceArmMsSqlDatabaseCreateUpdate(d *schema.ResourceData, meta interface

serverResp, err := serverClient.Get(ctx, serverId.ResourceGroup, serverId.Name)
if err != nil {
return fmt.Errorf("Failure in making Read request on MsSql Server %q (Resource Group %q): %s", serverId.Name, serverId.ResourceGroup, err)
return fmt.Errorf("making Read request on MsSql Server %q (Resource Group %q): %s", serverId.Name, serverId.ResourceGroup, err)
}

location := *serverResp.Location
Expand Down Expand Up @@ -272,16 +360,16 @@ func resourceArmMsSqlDatabaseCreateUpdate(d *schema.ResourceData, meta interface

future, err := client.CreateOrUpdate(ctx, serverId.ResourceGroup, serverId.Name, name, params)
if err != nil {
return fmt.Errorf("Failure in creating MsSql Database %q (Sql Server %q / Resource Group %q): %+v", name, serverId.Name, serverId.ResourceGroup, err)
return fmt.Errorf("creating MsSql Database %q (Sql Server %q / Resource Group %q): %+v", name, serverId.Name, serverId.ResourceGroup, err)
}

if err = future.WaitForCompletionRef(ctx, client.Client); err != nil {
return fmt.Errorf("Failure in waiting for creation of MsSql Database %q (MsSql Server Name %q / Resource Group %q): %+v", name, serverId.Name, serverId.ResourceGroup, err)
return fmt.Errorf("waiting for creation of MsSql Database %q (MsSql Server Name %q / Resource Group %q): %+v", name, serverId.Name, serverId.ResourceGroup, err)
}

read, err := client.Get(ctx, serverId.ResourceGroup, serverId.Name, name)
if err != nil {
return fmt.Errorf("Failure in retrieving MsSql Database %q (MsSql Server Name %q / Resource Group %q): %+v", name, serverId.Name, serverId.ResourceGroup, err)
return fmt.Errorf("retrieving MsSql Database %q (MsSql Server Name %q / Resource Group %q): %+v", name, serverId.Name, serverId.ResourceGroup, err)
}

if read.ID == nil || *read.ID == "" {
Expand All @@ -290,35 +378,40 @@ func resourceArmMsSqlDatabaseCreateUpdate(d *schema.ResourceData, meta interface

d.SetId(*read.ID)

if _, err = threatClient.CreateOrUpdate(ctx, serverId.ResourceGroup, serverId.Name, name, *expandArmMsSqlServerThreatDetectionPolicy(d, location)); err != nil {
return fmt.Errorf("setting database threat detection policy: %+v", err)
}

return resourceArmMsSqlDatabaseRead(d, meta)
}

func resourceArmMsSqlDatabaseRead(d *schema.ResourceData, meta interface{}) error {
client := meta.(*clients.Client).MSSQL.DatabasesClient
threatClient := meta.(*clients.Client).MSSQL.DatabaseThreatDetectionPoliciesClient
ctx, cancel := timeouts.ForRead(meta.(*clients.Client).StopContext, d)
defer cancel()

databaseId, err := parse.MsSqlDatabaseID(d.Id())
id, err := parse.MsSqlDatabaseID(d.Id())
if err != nil {
return err
}

resp, err := client.Get(ctx, databaseId.ResourceGroup, databaseId.MsSqlServer, databaseId.Name)
resp, err := client.Get(ctx, id.ResourceGroup, id.MsSqlServer, id.Name)
if err != nil {
if utils.ResponseWasNotFound(resp.Response) {
d.SetId("")
return nil
}
return fmt.Errorf("Failure in reading MsSql Database %s (MsSql Server Name %q / Resource Group %q): %s", databaseId.Name, databaseId.MsSqlServer, databaseId.ResourceGroup, err)
return fmt.Errorf("reading MsSql Database %s (MsSql Server Name %q / Resource Group %q): %s", id.Name, id.MsSqlServer, id.ResourceGroup, err)
}

d.Set("name", resp.Name)

serverClient := meta.(*clients.Client).MSSQL.ServersClient

serverResp, err := serverClient.Get(ctx, databaseId.ResourceGroup, databaseId.MsSqlServer)
serverResp, err := serverClient.Get(ctx, id.ResourceGroup, id.MsSqlServer)
if err != nil || *serverResp.ID == "" {
return fmt.Errorf("Failure in making Read request on MsSql Server %q (Resource Group %q): %s", databaseId.MsSqlServer, databaseId.ResourceGroup, err)
return fmt.Errorf("making Read request on MsSql Server %q (Resource Group %q): %s", id.MsSqlServer, id.ResourceGroup, err)
}
d.Set("server_id", serverResp.ID)

Expand All @@ -341,6 +434,13 @@ func resourceArmMsSqlDatabaseRead(d *schema.ResourceData, meta interface{}) erro
d.Set("zone_redundant", props.ZoneRedundant)
}

threat, err := threatClient.Get(ctx, id.ResourceGroup, id.MsSqlServer, id.Name)
if err == nil {
if err := d.Set("threat_detection_policy", flattenArmMsSqlServerThreatDetectionPolicy(d, threat)); err != nil {
return fmt.Errorf("setting `threat_detection_policy`: %+v", err)
}
}

return tags.FlattenAndSet(d, resp.Tags)
}

Expand All @@ -356,15 +456,117 @@ func resourceArmMsSqlDatabaseDelete(d *schema.ResourceData, meta interface{}) er

future, err := client.Delete(ctx, id.ResourceGroup, id.MsSqlServer, id.Name)
if err != nil {
return fmt.Errorf("Failure in deleting MsSql Database %q ( MsSql Server %q / Resource Group %q): %+v", id.Name, id.MsSqlServer, id.ResourceGroup, err)
return fmt.Errorf("deleting MsSql Database %q ( MsSql Server %q / Resource Group %q): %+v", id.Name, id.MsSqlServer, id.ResourceGroup, err)
}

if err = future.WaitForCompletionRef(ctx, client.Client); err != nil {
if response.WasNotFound(future.Response()) {
return nil
}
return fmt.Errorf("Failure in waiting for MsSql Database %q ( MsSql Server %q / Resource Group %q) to be deleted: %+v", id.Name, id.MsSqlServer, id.ResourceGroup, err)
return fmt.Errorf("waiting for MsSql Database %q ( MsSql Server %q / Resource Group %q) to be deleted: %+v", id.Name, id.MsSqlServer, id.ResourceGroup, err)
}

return nil
}

func flattenArmMsSqlServerThreatDetectionPolicy(d *schema.ResourceData, policy sql.DatabaseSecurityAlertPolicy) []interface{} {
// The SQL database threat detection API always returns the default value even if never set.
// If the values are on their default one, threat it as not set.
properties := policy.DatabaseSecurityAlertPolicyProperties
if properties == nil {
return []interface{}{}
}

threatDetectionPolicy := make(map[string]interface{})

threatDetectionPolicy["state"] = string(properties.State)
threatDetectionPolicy["email_account_admins"] = string(properties.EmailAccountAdmins)
threatDetectionPolicy["use_server_default"] = string(properties.UseServerDefault)

if disabledAlerts := properties.DisabledAlerts; disabledAlerts != nil {
flattenedAlerts := schema.NewSet(schema.HashString, []interface{}{})
if v := *disabledAlerts; v != "" {
parsedAlerts := strings.Split(v, ";")
for _, a := range parsedAlerts {
flattenedAlerts.Add(a)
}
}
threatDetectionPolicy["disabled_alerts"] = flattenedAlerts
}
if emailAddresses := properties.EmailAddresses; emailAddresses != nil {
flattenedEmails := schema.NewSet(schema.HashString, []interface{}{})
if v := *emailAddresses; v != "" {
parsedEmails := strings.Split(*emailAddresses, ";")
for _, e := range parsedEmails {
flattenedEmails.Add(e)
}
}
threatDetectionPolicy["email_addresses"] = flattenedEmails
}
if properties.StorageEndpoint != nil {
threatDetectionPolicy["storage_endpoint"] = *properties.StorageEndpoint
}
if properties.RetentionDays != nil {
threatDetectionPolicy["retention_days"] = int(*properties.RetentionDays)
}

// If storage account access key is in state read it to the new state, as the API does not return it for security reasons
if v, ok := d.GetOk("threat_detection_policy.0.storage_account_access_key"); ok {
threatDetectionPolicy["storage_account_access_key"] = v.(string)
}

return []interface{}{threatDetectionPolicy}
}

func expandArmMsSqlServerThreatDetectionPolicy(d *schema.ResourceData, location string) *sql.DatabaseSecurityAlertPolicy {
policy := sql.DatabaseSecurityAlertPolicy{
Location: utils.String(location),
DatabaseSecurityAlertPolicyProperties: &sql.DatabaseSecurityAlertPolicyProperties{
State: sql.SecurityAlertPolicyStateDisabled,
},
}
properties := policy.DatabaseSecurityAlertPolicyProperties

td, ok := d.GetOk("threat_detection_policy")
if !ok {
return &policy
}

if tdl := td.([]interface{}); len(tdl) > 0 {
threatDetection := tdl[0].(map[string]interface{})

properties.State = sql.SecurityAlertPolicyState(threatDetection["state"].(string))
properties.EmailAccountAdmins = sql.SecurityAlertPolicyEmailAccountAdmins(threatDetection["email_account_admins"].(string))
properties.UseServerDefault = sql.SecurityAlertPolicyUseServerDefault(threatDetection["use_server_default"].(string))

if v, ok := threatDetection["disabled_alerts"]; ok {
alerts := v.(*schema.Set).List()
expandedAlerts := make([]string, len(alerts))
for i, a := range alerts {
expandedAlerts[i] = a.(string)
}
properties.DisabledAlerts = utils.String(strings.Join(expandedAlerts, ";"))
}
if v, ok := threatDetection["email_addresses"]; ok {
emails := v.(*schema.Set).List()
expandedEmails := make([]string, len(emails))
for i, e := range emails {
expandedEmails[i] = e.(string)
}
properties.EmailAddresses = utils.String(strings.Join(expandedEmails, ";"))
}
if v, ok := threatDetection["retention_days"]; ok {
properties.RetentionDays = utils.Int32(int32(v.(int)))
}
if v, ok := threatDetection["storage_account_access_key"]; ok {
properties.StorageAccountAccessKey = utils.String(v.(string))
}
if v, ok := threatDetection["storage_endpoint"]; ok {
properties.StorageEndpoint = utils.String(v.(string))
}

return &policy
}

return &policy
}
Loading