Skip to content

Commit

Permalink
Add support partitions in policy data sources
Browse files Browse the repository at this point in the history
  • Loading branch information
ashenm committed Nov 3, 2024
1 parent 28b8f49 commit 0c45e23
Show file tree
Hide file tree
Showing 13 changed files with 187 additions and 31 deletions.
4 changes: 4 additions & 0 deletions aws/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package aws

var AwsPartitions = []string{"aws", "aws-cn", "aws-us-gov"}
var AwsPartitionsValidationError = "aws_partition must be either 'aws', 'aws-cn', or 'aws-us-gov'"
13 changes: 10 additions & 3 deletions aws/data_aws_bucket_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ func DataAwsBucketPolicy() common.Resource {
return common.Resource{
Read: func(ctx context.Context, d *schema.ResourceData, c *common.DatabricksClient) error {
bucket := d.Get("bucket").(string)
awsPartition := d.Get("aws_partition").(string)
policy := awsIamPolicy{
Version: "2012-10-17",
Statements: []*awsIamPolicyStatement{
Expand All @@ -30,11 +31,11 @@ func DataAwsBucketPolicy() common.Resource {
"s3:GetBucketLocation",
},
Resources: []string{
fmt.Sprintf("arn:aws:s3:::%s/*", bucket),
fmt.Sprintf("arn:aws:s3:::%s", bucket),
fmt.Sprintf("arn:%s:s3:::%s/*", awsPartition, bucket),
fmt.Sprintf("arn:%s:s3:::%s", awsPartition, bucket),
},
Principal: map[string]string{
"AWS": fmt.Sprintf("arn:aws:iam::%s:root", d.Get("databricks_account_id").(string)),
"AWS": fmt.Sprintf("arn:%s:iam::%s:root", awsPartition, d.Get("databricks_account_id").(string)),
},
},
},
Expand All @@ -60,6 +61,12 @@ func DataAwsBucketPolicy() common.Resource {
return nil
},
Schema: map[string]*schema.Schema{
"aws_partition": {
Type: schema.TypeString,
Optional: true,
ValidateFunc: validation.StringInSlice(AwsPartitions, false),
Default: "aws",
},
"databricks_account_id": {
Type: schema.TypeString,
Default: "414351767826",
Expand Down
16 changes: 16 additions & 0 deletions aws/data_aws_bucket_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,19 @@ func TestDataAwsBucketPolicyConfusedDeputyProblem(t *testing.T) {
j := d.Get("json")
assert.Lenf(t, j, 575, "Strange length for policy: %s", j)
}

func TestDataAwsBucketPolicyPartitionGov(t *testing.T) {
d, err := qa.ResourceFixture{
Read: true,
Resource: DataAwsBucketPolicy(),
NonWritable: true,
ID: ".",
HCL: `
bucket = "abc"
aws_partition = "aws-us-gov"
`,
}.Apply(t)
assert.NoError(t, err)
j := d.Get("json")
assert.Lenf(t, j, 461, "Strange length for policy: %s", j)
}
49 changes: 28 additions & 21 deletions aws/data_aws_crossaccount_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package aws
import (
"context"
"encoding/json"
"errors"
"fmt"
"regexp"
"slices"
Expand All @@ -17,11 +18,16 @@ func DataAwsCrossaccountPolicy() common.Resource {
PassRole []string `json:"pass_roles,omitempty"`
JSON string `json:"json" tf:"computed"`
AwsAccountId string `json:"aws_account_id,omitempty"`
AwsPartition string `json:"aws_partition,omitempty" tf:"default:aws"`
VpcId string `json:"vpc_id,omitempty"`
Region string `json:"region,omitempty"`
SecurityGroupId string `json:"security_group_id,omitempty"`
}
return common.NoClientData(func(ctx context.Context, data *AwsCrossAccountPolicy) error {
if !slices.Contains(AwsPartitions, data.AwsPartition) {
return errors.New(AwsPartitionsValidationError)
}

if !slices.Contains([]string{"managed", "customer", "restricted"}, data.PolicyType) {
return fmt.Errorf("policy_type must be either 'managed', 'customer' or 'restricted'")
}
Expand Down Expand Up @@ -168,6 +174,7 @@ func DataAwsCrossaccountPolicy() common.Resource {
if data.PolicyType == "restricted" {
region := data.Region
aws_account_id := data.AwsAccountId
awsPartition := data.AwsPartition
vpc_id := data.VpcId
security_group_id := data.SecurityGroupId
policy.Statements = append(policy.Statements,
Expand All @@ -179,7 +186,7 @@ func DataAwsCrossaccountPolicy() common.Resource {
"ec2:DisassociateIamInstanceProfile",
"ec2:ReplaceIamInstanceProfileAssociation",
},
Resources: fmt.Sprintf("arn:aws:ec2:%s:%s:instance/*", region, aws_account_id),
Resources: fmt.Sprintf("arn:%s:ec2:%s:%s:instance/*", awsPartition, region, aws_account_id),
Condition: map[string]map[string]string{
"StringEquals": {
"ec2:ResourceTag/Vendor": "Databricks",
Expand All @@ -191,8 +198,8 @@ func DataAwsCrossaccountPolicy() common.Resource {
Effect: "Allow",
Actions: "ec2:RunInstances",
Resources: []string{
fmt.Sprintf("arn:aws:ec2:%s:%s:volume/*", region, aws_account_id),
fmt.Sprintf("arn:aws:ec2:%s:%s:instance/*", region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:volume/*", awsPartition, region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:instance/*", awsPartition, region, aws_account_id),
},
Condition: map[string]map[string]string{
"StringEquals": {
Expand All @@ -204,7 +211,7 @@ func DataAwsCrossaccountPolicy() common.Resource {
Sid: "AllowEc2RunInstanceImagePerTag",
Effect: "Allow",
Actions: "ec2:RunInstances",
Resources: fmt.Sprintf("arn:aws:ec2:%s:%s:image/*", region, aws_account_id),
Resources: fmt.Sprintf("arn:%s:ec2:%s:%s:image/*", awsPartition, region, aws_account_id),
Condition: map[string]map[string]string{
"StringEquals": {
"aws:ResourceTag/Vendor": "Databricks",
Expand All @@ -216,13 +223,13 @@ func DataAwsCrossaccountPolicy() common.Resource {
Effect: "Allow",
Actions: "ec2:RunInstances",
Resources: []string{
fmt.Sprintf("arn:aws:ec2:%s:%s:network-interface/*", region, aws_account_id),
fmt.Sprintf("arn:aws:ec2:%s:%s:subnet/*", region, aws_account_id),
fmt.Sprintf("arn:aws:ec2:%s:%s:security-group/*", region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:network-interface/*", awsPartition, region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:subnet/*", awsPartition, region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:security-group/*", awsPartition, region, aws_account_id),
},
Condition: map[string]map[string]string{
"StringEquals": {
"ec2:vpc": fmt.Sprintf("arn:aws:ec2:%s:%s:vpc/%s", region, aws_account_id, vpc_id),
"ec2:vpc": fmt.Sprintf("arn:%s:ec2:%s:%s:vpc/%s", awsPartition, region, aws_account_id, vpc_id),
},
},
},
Expand All @@ -231,19 +238,19 @@ func DataAwsCrossaccountPolicy() common.Resource {
Effect: "Allow",
Actions: "ec2:RunInstances",
NotResources: []string{
fmt.Sprintf("arn:aws:ec2:%s:%s:image/*", region, aws_account_id),
fmt.Sprintf("arn:aws:ec2:%s:%s:network-interface/*", region, aws_account_id),
fmt.Sprintf("arn:aws:ec2:%s:%s:subnet/*", region, aws_account_id),
fmt.Sprintf("arn:aws:ec2:%s:%s:security-group/*", region, aws_account_id),
fmt.Sprintf("arn:aws:ec2:%s:%s:volume/*", region, aws_account_id),
fmt.Sprintf("arn:aws:ec2:%s:%s:instance/*", region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:image/*", awsPartition, region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:network-interface/*", awsPartition, region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:subnet/*", awsPartition, region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:security-group/*", awsPartition, region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:volume/*", awsPartition, region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:instance/*", awsPartition, region, aws_account_id),
},
},
&awsIamPolicyStatement{
Sid: "EC2TerminateInstancesTag",
Effect: "Allow",
Actions: "ec2:TerminateInstances",
Resources: fmt.Sprintf("arn:aws:ec2:%s:%s:instance/*", region, aws_account_id),
Resources: fmt.Sprintf("arn:%s:ec2:%s:%s:instance/*", awsPartition, region, aws_account_id),
Condition: map[string]map[string]string{
"StringEquals": {
"ec2:ResourceTag/Vendor": "Databricks",
Expand All @@ -258,8 +265,8 @@ func DataAwsCrossaccountPolicy() common.Resource {
"ec2:DetachVolume",
},
Resources: []string{
fmt.Sprintf("arn:aws:ec2:%s:%s:instance/*", region, aws_account_id),
fmt.Sprintf("arn:aws:ec2:%s:%s:volume/*", region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:instance/*", awsPartition, region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:volume/*", awsPartition, region, aws_account_id),
},
Condition: map[string]map[string]string{
"StringEquals": {
Expand All @@ -271,7 +278,7 @@ func DataAwsCrossaccountPolicy() common.Resource {
Sid: "EC2CreateVolumeByTag",
Effect: "Allow",
Actions: "ec2:CreateVolume",
Resources: fmt.Sprintf("arn:aws:ec2:%s:%s:volume/*", region, aws_account_id),
Resources: fmt.Sprintf("arn:%s:ec2:%s:%s:volume/*", awsPartition, region, aws_account_id),
Condition: map[string]map[string]string{
"StringEquals": {
"aws:RequestTag/Vendor": "Databricks",
Expand All @@ -283,7 +290,7 @@ func DataAwsCrossaccountPolicy() common.Resource {
Effect: "Allow",
Actions: "ec2:DeleteVolume",
Resources: []string{
fmt.Sprintf("arn:aws:ec2:%s:%s:volume/*", region, aws_account_id),
fmt.Sprintf("arn:%s:ec2:%s:%s:volume/*", awsPartition, region, aws_account_id),
},
Condition: map[string]map[string]string{
"StringEquals": {
Expand All @@ -300,10 +307,10 @@ func DataAwsCrossaccountPolicy() common.Resource {
"ec2:RevokeSecurityGroupEgress",
"ec2:RevokeSecurityGroupIngress",
},
Resources: fmt.Sprintf("arn:aws:ec2:%s:%s:security-group/%s", region, aws_account_id, security_group_id),
Resources: fmt.Sprintf("arn:%s:ec2:%s:%s:security-group/%s", awsPartition, region, aws_account_id, security_group_id),
Condition: map[string]map[string]string{
"StringEquals": {
"ec2:vpc": fmt.Sprintf("arn:aws:ec2:%s:%s:vpc/%s", region, aws_account_id, vpc_id),
"ec2:vpc": fmt.Sprintf("arn:%s:ec2:%s:%s:vpc/%s", awsPartition, region, aws_account_id, vpc_id),
},
},
},
Expand Down
29 changes: 29 additions & 0 deletions aws/data_aws_crossaccount_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,25 @@ func TestDataAwsCrossAccountRestrictedPolicy(t *testing.T) {
assert.Lenf(t, j, 5725, "Strange length for policy: %s", j)
}

func TestDataAwsCrossAccountRestrictedPolicyPartitionGov(t *testing.T) {
d, err := qa.ResourceFixture{
Read: true,
Resource: DataAwsCrossaccountPolicy(),
NonWritable: true,
HCL: `
policy_type = "restricted"
aws_account_id = "123456789012"
aws_partition = "aws-us-gov"
vpc_id = "vpc-12345678"
region = "us-west-2"
security_group_id = "sg-12345678"`,
ID: ".",
}.Apply(t)
assert.NoError(t, err)
j := d.Get("json")
assert.Lenf(t, j, 5872, "Strange length for policy: %s", j)
}

func TestDataAwsCrossAccountInvalidPolicy(t *testing.T) {
qa.ResourceFixture{
Read: true,
Expand All @@ -552,6 +571,16 @@ func TestDataAwsCrossAccountInvalidAccountId(t *testing.T) {
}.ExpectError(t, "aws_account_id must be a 12 digit number")
}

func TestDataAwsCrossAccountInvalidPartition(t *testing.T) {
qa.ResourceFixture{
Read: true,
Resource: DataAwsCrossaccountPolicy(),
NonWritable: true,
HCL: `aws_partition = "something"`,
ID: ".",
}.ExpectError(t, AwsPartitionsValidationError)
}

func TestDataAwsCrossAccountInvalidVpcId(t *testing.T) {
qa.ResourceFixture{
Read: true,
Expand Down
12 changes: 10 additions & 2 deletions aws/data_aws_unity_catalog_assume_role_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package aws
import (
"context"
"encoding/json"
"errors"
"fmt"
"slices"

"github.com/databricks/terraform-provider-databricks/common"
)
Expand All @@ -14,13 +16,19 @@ func DataAwsUnityCatalogAssumeRolePolicy() common.Resource {
UnityCatalogIamArn string `json:"unity_catalog_iam_arn,omitempty" tf:"computed"`
ExternalId string `json:"external_id"`
AwsAccountId string `json:"aws_account_id"`
AwsPartition string `json:"aws_partition,omitempty" tf:"default:aws"`
JSON string `json:"json" tf:"computed"`
Id string `json:"id" tf:"computed"`
}
return common.NoClientData(func(ctx context.Context, data *AwsUcAssumeRolePolicy) error {
if data.UnityCatalogIamArn == "" {
data.UnityCatalogIamArn = "arn:aws:iam::414351767826:role/unity-catalog-prod-UCMasterRole-14S5ZJVKOTYTL"
}

if !slices.Contains(AwsPartitions, data.AwsPartition) {
return errors.New(AwsPartitionsValidationError)
}

policy := awsIamPolicy{
Version: "2012-10-17",
Statements: []*awsIamPolicyStatement{
Expand All @@ -43,11 +51,11 @@ func DataAwsUnityCatalogAssumeRolePolicy() common.Resource {
Actions: "sts:AssumeRole",
Condition: map[string]map[string]string{
"ArnLike": {
"aws:PrincipalArn": fmt.Sprintf("arn:aws:iam::%s:role/%s", data.AwsAccountId, data.RoleName),
"aws:PrincipalArn": fmt.Sprintf("arn:%s:iam::%s:role/%s", data.AwsPartition, data.AwsAccountId, data.RoleName),
},
},
Principal: map[string]string{
"AWS": fmt.Sprintf("arn:aws:iam::%s:root", data.AwsAccountId),
"AWS": fmt.Sprintf("arn:%s:iam::%s:root", data.AwsPartition, data.AwsAccountId),
},
},
},
Expand Down
16 changes: 16 additions & 0 deletions aws/data_aws_unity_catalog_assume_role_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,19 @@ func TestDataAwsUnityCatalogAssumeRolePolicyWithoutUcArn(t *testing.T) {
}`
compareJSON(t, j, p)
}

func TestDataAwsUnityCatalogAssumeRolePolicyInvalidPartition(t *testing.T) {
qa.ResourceFixture{
Read: true,
Resource: DataAwsUnityCatalogAssumeRolePolicy(),
NonWritable: true,
ID: ".",
HCL: `
aws_account_id = "123456789098"
aws_partition = "something"
role_name = "databricks-role"
unity_catalog_iam_arn = "arn:aws:iam::414351767826:role/unity-catalog-prod-UCMasterRole-14S5ZJVKOTYTL"
external_id = "12345"
`,
}.ExpectError(t, AwsPartitionsValidationError)
}
17 changes: 12 additions & 5 deletions aws/data_aws_unity_catalog_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
func generateReadContext(ctx context.Context, d *schema.ResourceData, m *common.DatabricksClient) error {
bucket := d.Get("bucket_name").(string)
awsAccountId := d.Get("aws_account_id").(string)
awsPartition := d.Get("aws_partition").(string)
roleName := d.Get("role_name").(string)
policy := awsIamPolicy{
Version: "2012-10-17",
Expand All @@ -29,8 +30,8 @@ func generateReadContext(ctx context.Context, d *schema.ResourceData, m *common.
"s3:GetBucketLocation",
},
Resources: []string{
fmt.Sprintf("arn:aws:s3:::%s/*", bucket),
fmt.Sprintf("arn:aws:s3:::%s", bucket),
fmt.Sprintf("arn:%s:s3:::%s/*", awsPartition, bucket),
fmt.Sprintf("arn:%s:s3:::%s", awsPartition, bucket),
},
},
{
Expand All @@ -39,14 +40,14 @@ func generateReadContext(ctx context.Context, d *schema.ResourceData, m *common.
"sts:AssumeRole",
},
Resources: []string{
fmt.Sprintf("arn:aws:iam::%s:role/%s", awsAccountId, roleName),
fmt.Sprintf("arn:%s:iam::%s:role/%s", awsPartition, awsAccountId, roleName),
},
},
},
}
if kmsKey, ok := d.GetOk("kms_name"); ok {
kmsArn := fmt.Sprintf("arn:aws:kms:%s", kmsKey)
if strings.HasPrefix(kmsKey.(string), "arn:aws") {
kmsArn := fmt.Sprintf("arn:%s:kms:%s", awsPartition, kmsKey)
if strings.HasPrefix(kmsKey.(string), fmt.Sprintf("arn:%s", awsPartition)) {
kmsArn = kmsKey.(string)
}
policy.Statements = append(policy.Statements, &awsIamPolicyStatement{
Expand Down Expand Up @@ -92,6 +93,12 @@ func validateSchema() map[string]*schema.Schema {
Type: schema.TypeString,
Required: true,
},
"aws_partition": {
Type: schema.TypeString,
Optional: true,
ValidateFunc: validation.StringInSlice(AwsPartitions, false),
Default: "aws",
},
"json": {
Type: schema.TypeString,
Computed: true,
Expand Down
Loading

0 comments on commit 0c45e23

Please sign in to comment.