Skip to content

Commit

Permalink
mssql_database - support the threat_detection_policy property (#6437)
Browse files Browse the repository at this point in the history
  • Loading branch information
katbyte authored Apr 12, 2020
1 parent fba5506 commit aeb1d21
Show file tree
Hide file tree
Showing 7 changed files with 337 additions and 45 deletions.
47 changes: 26 additions & 21 deletions azurerm/internal/services/mssql/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

type Client struct {
DatabasesClient *sql.DatabasesClient
DatabaseThreatDetectionPoliciesClient *sql.DatabaseThreatDetectionPoliciesClient
ElasticPoolsClient *sql.ElasticPoolsClient
DatabaseVulnerabilityAssessmentRuleBaselinesClient *sql.DatabaseVulnerabilityAssessmentRuleBaselinesClient
ServersClient *sql.ServersClient
Expand All @@ -17,34 +18,38 @@ type Client struct {
}

func NewClient(o *common.ClientOptions) *Client {
DatabasesClient := sql.NewDatabasesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&DatabasesClient.Client, o.ResourceManagerAuthorizer)
databasesClient := sql.NewDatabasesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&databasesClient.Client, o.ResourceManagerAuthorizer)

ElasticPoolsClient := sql.NewElasticPoolsClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&ElasticPoolsClient.Client, o.ResourceManagerAuthorizer)
databaseThreatDetectionPoliciesClient := sql.NewDatabaseThreatDetectionPoliciesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&databaseThreatDetectionPoliciesClient.Client, o.ResourceManagerAuthorizer)

DatabaseVulnerabilityAssessmentRuleBaselinesClient := sql.NewDatabaseVulnerabilityAssessmentRuleBaselinesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&DatabaseVulnerabilityAssessmentRuleBaselinesClient.Client, o.ResourceManagerAuthorizer)
databaseVulnerabilityAssessmentRuleBaselinesClient := sql.NewDatabaseVulnerabilityAssessmentRuleBaselinesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&databaseVulnerabilityAssessmentRuleBaselinesClient.Client, o.ResourceManagerAuthorizer)

ServerSecurityAlertPoliciesClient := sql.NewServerSecurityAlertPoliciesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&ServerSecurityAlertPoliciesClient.Client, o.ResourceManagerAuthorizer)
elasticPoolsClient := sql.NewElasticPoolsClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&elasticPoolsClient.Client, o.ResourceManagerAuthorizer)

ServersClient := sql.NewServersClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&ServersClient.Client, o.ResourceManagerAuthorizer)
serverSecurityAlertPoliciesClient := sql.NewServerSecurityAlertPoliciesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&serverSecurityAlertPoliciesClient.Client, o.ResourceManagerAuthorizer)

ServerVulnerabilityAssessmentsClient := sql.NewServerVulnerabilityAssessmentsClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&ServerVulnerabilityAssessmentsClient.Client, o.ResourceManagerAuthorizer)
serversClient := sql.NewServersClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&serversClient.Client, o.ResourceManagerAuthorizer)

SQLVirtualMachinesClient := sqlvirtualmachine.NewSQLVirtualMachinesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&SQLVirtualMachinesClient.Client, o.ResourceManagerAuthorizer)
serverVulnerabilityAssessmentsClient := sql.NewServerVulnerabilityAssessmentsClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&serverVulnerabilityAssessmentsClient.Client, o.ResourceManagerAuthorizer)

sqlVirtualMachinesClient := sqlvirtualmachine.NewSQLVirtualMachinesClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&sqlVirtualMachinesClient.Client, o.ResourceManagerAuthorizer)

return &Client{
DatabasesClient: &DatabasesClient,
ElasticPoolsClient: &ElasticPoolsClient,
DatabaseVulnerabilityAssessmentRuleBaselinesClient: &DatabaseVulnerabilityAssessmentRuleBaselinesClient,
ServersClient: &ServersClient,
ServerSecurityAlertPoliciesClient: &ServerSecurityAlertPoliciesClient,
ServerVulnerabilityAssessmentsClient: &ServerVulnerabilityAssessmentsClient,
VirtualMachinesClient: &SQLVirtualMachinesClient,
DatabasesClient: &databasesClient,
DatabaseThreatDetectionPoliciesClient: &databaseThreatDetectionPoliciesClient,
DatabaseVulnerabilityAssessmentRuleBaselinesClient: &databaseVulnerabilityAssessmentRuleBaselinesClient,
ElasticPoolsClient: &elasticPoolsClient,
ServersClient: &serversClient,
ServerSecurityAlertPoliciesClient: &serverSecurityAlertPoliciesClient,
ServerVulnerabilityAssessmentsClient: &serverVulnerabilityAssessmentsClient,
VirtualMachinesClient: &sqlVirtualMachinesClient,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func dataSourceArmMsSqlDatabaseRead(d *schema.ResourceData, meta interface{}) er
return fmt.Errorf("Database %q (Resource Group %q, SQL Server %q) was not found", name, serverId.ResourceGroup, serverId.Name)
}

return fmt.Errorf("Failure in making Read request on AzureRM Database %s (Resource Group %q, SQL Server %q): %+v", name, serverId.ResourceGroup, serverId.Name, err)
return fmt.Errorf("making Read request on AzureRM Database %s (Resource Group %q, SQL Server %q): %+v", name, serverId.ResourceGroup, serverId.Name, err)
}

if id := resp.ID; id != nil {
Expand Down
224 changes: 213 additions & 11 deletions azurerm/internal/services/mssql/resource_arm_mssql_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mssql
import (
"fmt"
"log"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/services/preview/sql/mgmt/v3.0/sql"
Expand Down Expand Up @@ -176,6 +177,92 @@ func resourceArmMsSqlDatabase() *schema.Resource {
Computed: true,
},

"threat_detection_policy": {
Type: schema.TypeList,
Optional: true,
Computed: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"disabled_alerts": {
Type: schema.TypeSet,
Optional: true,
Set: schema.HashString,
Elem: &schema.Schema{
Type: schema.TypeString,
ValidateFunc: validation.StringInSlice([]string{
"Sql_Injection",
"Sql_Injection_Vulnerability",
"Access_Anomaly",
}, true),
},
},

"email_account_admins": {
Type: schema.TypeString,
Optional: true,
DiffSuppressFunc: suppress.CaseDifference,
Default: string(sql.SecurityAlertPolicyEmailAccountAdminsDisabled),
ValidateFunc: validation.StringInSlice([]string{
string(sql.SecurityAlertPolicyEmailAccountAdminsDisabled),
string(sql.SecurityAlertPolicyEmailAccountAdminsEnabled),
}, true),
},

"email_addresses": {
Type: schema.TypeSet,
Optional: true,
Elem: &schema.Schema{
Type: schema.TypeString,
},
Set: schema.HashString,
},

"retention_days": {
Type: schema.TypeInt,
Optional: true,
ValidateFunc: validation.IntAtLeast(0),
},

"state": {
Type: schema.TypeString,
Optional: true,
DiffSuppressFunc: suppress.CaseDifference,
Default: string(sql.SecurityAlertPolicyStateDisabled),
ValidateFunc: validation.StringInSlice([]string{
string(sql.SecurityAlertPolicyStateDisabled),
string(sql.SecurityAlertPolicyStateEnabled),
string(sql.SecurityAlertPolicyStateNew),
}, true),
},

"storage_account_access_key": {
Type: schema.TypeString,
Optional: true,
Sensitive: true,
ValidateFunc: validation.StringIsNotEmpty,
},

"storage_endpoint": {
Type: schema.TypeString,
Optional: true,
ValidateFunc: validation.StringIsNotEmpty,
},

"use_server_default": {
Type: schema.TypeString,
Optional: true,
DiffSuppressFunc: suppress.CaseDifference,
Default: string(sql.SecurityAlertPolicyUseServerDefaultDisabled),
ValidateFunc: validation.StringInSlice([]string{
string(sql.SecurityAlertPolicyUseServerDefaultDisabled),
string(sql.SecurityAlertPolicyUseServerDefaultEnabled),
}, true),
},
},
},
},

"tags": tags.Schema(),
},
}
Expand All @@ -184,6 +271,7 @@ func resourceArmMsSqlDatabase() *schema.Resource {
func resourceArmMsSqlDatabaseCreateUpdate(d *schema.ResourceData, meta interface{}) error {
client := meta.(*clients.Client).MSSQL.DatabasesClient
serverClient := meta.(*clients.Client).MSSQL.ServersClient
threatClient := meta.(*clients.Client).MSSQL.DatabaseThreatDetectionPoliciesClient
ctx, cancel := timeouts.ForCreateUpdate(meta.(*clients.Client).StopContext, d)
defer cancel()

Expand All @@ -208,7 +296,7 @@ func resourceArmMsSqlDatabaseCreateUpdate(d *schema.ResourceData, meta interface

serverResp, err := serverClient.Get(ctx, serverId.ResourceGroup, serverId.Name)
if err != nil {
return fmt.Errorf("Failure in making Read request on MsSql Server %q (Resource Group %q): %s", serverId.Name, serverId.ResourceGroup, err)
return fmt.Errorf("making Read request on MsSql Server %q (Resource Group %q): %s", serverId.Name, serverId.ResourceGroup, err)
}

location := *serverResp.Location
Expand Down Expand Up @@ -272,16 +360,16 @@ func resourceArmMsSqlDatabaseCreateUpdate(d *schema.ResourceData, meta interface

future, err := client.CreateOrUpdate(ctx, serverId.ResourceGroup, serverId.Name, name, params)
if err != nil {
return fmt.Errorf("Failure in creating MsSql Database %q (Sql Server %q / Resource Group %q): %+v", name, serverId.Name, serverId.ResourceGroup, err)
return fmt.Errorf("creating MsSql Database %q (Sql Server %q / Resource Group %q): %+v", name, serverId.Name, serverId.ResourceGroup, err)
}

if err = future.WaitForCompletionRef(ctx, client.Client); err != nil {
return fmt.Errorf("Failure in waiting for creation of MsSql Database %q (MsSql Server Name %q / Resource Group %q): %+v", name, serverId.Name, serverId.ResourceGroup, err)
return fmt.Errorf("waiting for creation of MsSql Database %q (MsSql Server Name %q / Resource Group %q): %+v", name, serverId.Name, serverId.ResourceGroup, err)
}

read, err := client.Get(ctx, serverId.ResourceGroup, serverId.Name, name)
if err != nil {
return fmt.Errorf("Failure in retrieving MsSql Database %q (MsSql Server Name %q / Resource Group %q): %+v", name, serverId.Name, serverId.ResourceGroup, err)
return fmt.Errorf("retrieving MsSql Database %q (MsSql Server Name %q / Resource Group %q): %+v", name, serverId.Name, serverId.ResourceGroup, err)
}

if read.ID == nil || *read.ID == "" {
Expand All @@ -290,35 +378,40 @@ func resourceArmMsSqlDatabaseCreateUpdate(d *schema.ResourceData, meta interface

d.SetId(*read.ID)

if _, err = threatClient.CreateOrUpdate(ctx, serverId.ResourceGroup, serverId.Name, name, *expandArmMsSqlServerThreatDetectionPolicy(d, location)); err != nil {
return fmt.Errorf("setting database threat detection policy: %+v", err)
}

return resourceArmMsSqlDatabaseRead(d, meta)
}

func resourceArmMsSqlDatabaseRead(d *schema.ResourceData, meta interface{}) error {
client := meta.(*clients.Client).MSSQL.DatabasesClient
threatClient := meta.(*clients.Client).MSSQL.DatabaseThreatDetectionPoliciesClient
ctx, cancel := timeouts.ForRead(meta.(*clients.Client).StopContext, d)
defer cancel()

databaseId, err := parse.MsSqlDatabaseID(d.Id())
id, err := parse.MsSqlDatabaseID(d.Id())
if err != nil {
return err
}

resp, err := client.Get(ctx, databaseId.ResourceGroup, databaseId.MsSqlServer, databaseId.Name)
resp, err := client.Get(ctx, id.ResourceGroup, id.MsSqlServer, id.Name)
if err != nil {
if utils.ResponseWasNotFound(resp.Response) {
d.SetId("")
return nil
}
return fmt.Errorf("Failure in reading MsSql Database %s (MsSql Server Name %q / Resource Group %q): %s", databaseId.Name, databaseId.MsSqlServer, databaseId.ResourceGroup, err)
return fmt.Errorf("reading MsSql Database %s (MsSql Server Name %q / Resource Group %q): %s", id.Name, id.MsSqlServer, id.ResourceGroup, err)
}

d.Set("name", resp.Name)

serverClient := meta.(*clients.Client).MSSQL.ServersClient

serverResp, err := serverClient.Get(ctx, databaseId.ResourceGroup, databaseId.MsSqlServer)
serverResp, err := serverClient.Get(ctx, id.ResourceGroup, id.MsSqlServer)
if err != nil || *serverResp.ID == "" {
return fmt.Errorf("Failure in making Read request on MsSql Server %q (Resource Group %q): %s", databaseId.MsSqlServer, databaseId.ResourceGroup, err)
return fmt.Errorf("making Read request on MsSql Server %q (Resource Group %q): %s", id.MsSqlServer, id.ResourceGroup, err)
}
d.Set("server_id", serverResp.ID)

Expand All @@ -341,6 +434,13 @@ func resourceArmMsSqlDatabaseRead(d *schema.ResourceData, meta interface{}) erro
d.Set("zone_redundant", props.ZoneRedundant)
}

threat, err := threatClient.Get(ctx, id.ResourceGroup, id.MsSqlServer, id.Name)
if err == nil {
if err := d.Set("threat_detection_policy", flattenArmMsSqlServerThreatDetectionPolicy(d, threat)); err != nil {
return fmt.Errorf("setting `threat_detection_policy`: %+v", err)
}
}

return tags.FlattenAndSet(d, resp.Tags)
}

Expand All @@ -356,15 +456,117 @@ func resourceArmMsSqlDatabaseDelete(d *schema.ResourceData, meta interface{}) er

future, err := client.Delete(ctx, id.ResourceGroup, id.MsSqlServer, id.Name)
if err != nil {
return fmt.Errorf("Failure in deleting MsSql Database %q ( MsSql Server %q / Resource Group %q): %+v", id.Name, id.MsSqlServer, id.ResourceGroup, err)
return fmt.Errorf("deleting MsSql Database %q ( MsSql Server %q / Resource Group %q): %+v", id.Name, id.MsSqlServer, id.ResourceGroup, err)
}

if err = future.WaitForCompletionRef(ctx, client.Client); err != nil {
if response.WasNotFound(future.Response()) {
return nil
}
return fmt.Errorf("Failure in waiting for MsSql Database %q ( MsSql Server %q / Resource Group %q) to be deleted: %+v", id.Name, id.MsSqlServer, id.ResourceGroup, err)
return fmt.Errorf("waiting for MsSql Database %q ( MsSql Server %q / Resource Group %q) to be deleted: %+v", id.Name, id.MsSqlServer, id.ResourceGroup, err)
}

return nil
}

func flattenArmMsSqlServerThreatDetectionPolicy(d *schema.ResourceData, policy sql.DatabaseSecurityAlertPolicy) []interface{} {
// The SQL database threat detection API always returns the default value even if never set.
// If the values are on their default one, threat it as not set.
properties := policy.DatabaseSecurityAlertPolicyProperties
if properties == nil {
return []interface{}{}
}

threatDetectionPolicy := make(map[string]interface{})

threatDetectionPolicy["state"] = string(properties.State)
threatDetectionPolicy["email_account_admins"] = string(properties.EmailAccountAdmins)
threatDetectionPolicy["use_server_default"] = string(properties.UseServerDefault)

if disabledAlerts := properties.DisabledAlerts; disabledAlerts != nil {
flattenedAlerts := schema.NewSet(schema.HashString, []interface{}{})
if v := *disabledAlerts; v != "" {
parsedAlerts := strings.Split(v, ";")
for _, a := range parsedAlerts {
flattenedAlerts.Add(a)
}
}
threatDetectionPolicy["disabled_alerts"] = flattenedAlerts
}
if emailAddresses := properties.EmailAddresses; emailAddresses != nil {
flattenedEmails := schema.NewSet(schema.HashString, []interface{}{})
if v := *emailAddresses; v != "" {
parsedEmails := strings.Split(*emailAddresses, ";")
for _, e := range parsedEmails {
flattenedEmails.Add(e)
}
}
threatDetectionPolicy["email_addresses"] = flattenedEmails
}
if properties.StorageEndpoint != nil {
threatDetectionPolicy["storage_endpoint"] = *properties.StorageEndpoint
}
if properties.RetentionDays != nil {
threatDetectionPolicy["retention_days"] = int(*properties.RetentionDays)
}

// If storage account access key is in state read it to the new state, as the API does not return it for security reasons
if v, ok := d.GetOk("threat_detection_policy.0.storage_account_access_key"); ok {
threatDetectionPolicy["storage_account_access_key"] = v.(string)
}

return []interface{}{threatDetectionPolicy}
}

func expandArmMsSqlServerThreatDetectionPolicy(d *schema.ResourceData, location string) *sql.DatabaseSecurityAlertPolicy {
policy := sql.DatabaseSecurityAlertPolicy{
Location: utils.String(location),
DatabaseSecurityAlertPolicyProperties: &sql.DatabaseSecurityAlertPolicyProperties{
State: sql.SecurityAlertPolicyStateDisabled,
},
}
properties := policy.DatabaseSecurityAlertPolicyProperties

td, ok := d.GetOk("threat_detection_policy")
if !ok {
return &policy
}

if tdl := td.([]interface{}); len(tdl) > 0 {
threatDetection := tdl[0].(map[string]interface{})

properties.State = sql.SecurityAlertPolicyState(threatDetection["state"].(string))
properties.EmailAccountAdmins = sql.SecurityAlertPolicyEmailAccountAdmins(threatDetection["email_account_admins"].(string))
properties.UseServerDefault = sql.SecurityAlertPolicyUseServerDefault(threatDetection["use_server_default"].(string))

if v, ok := threatDetection["disabled_alerts"]; ok {
alerts := v.(*schema.Set).List()
expandedAlerts := make([]string, len(alerts))
for i, a := range alerts {
expandedAlerts[i] = a.(string)
}
properties.DisabledAlerts = utils.String(strings.Join(expandedAlerts, ";"))
}
if v, ok := threatDetection["email_addresses"]; ok {
emails := v.(*schema.Set).List()
expandedEmails := make([]string, len(emails))
for i, e := range emails {
expandedEmails[i] = e.(string)
}
properties.EmailAddresses = utils.String(strings.Join(expandedEmails, ";"))
}
if v, ok := threatDetection["retention_days"]; ok {
properties.RetentionDays = utils.Int32(int32(v.(int)))
}
if v, ok := threatDetection["storage_account_access_key"]; ok {
properties.StorageAccountAccessKey = utils.String(v.(string))
}
if v, ok := threatDetection["storage_endpoint"]; ok {
properties.StorageEndpoint = utils.String(v.(string))
}

return &policy
}

return &policy
}
Loading

0 comments on commit aeb1d21

Please sign in to comment.