Skip to content

Commit

Permalink
Merge pull request #25557 from neitomic/sagemaker-model-image-repo-auth
Browse files Browse the repository at this point in the history
s/sagemaker Add support repository_auth_config in image_config
  • Loading branch information
ewbankkit authored Jun 27, 2022
2 parents 7781977 + e8a931b commit 1558ba3
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 28 deletions.
3 changes: 3 additions & 0 deletions .changelog/25557.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/sagemaker: Add `repository_auth_config` arguments in support of [Private Docker Registry](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-containers-inference-private.html)
```
70 changes: 68 additions & 2 deletions internal/service/sagemaker/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ func ResourceModel() *schema.Resource {
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.RepositoryAccessMode_Values(), false),
},
"repository_auth_config": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"repository_credentials_provider_arn": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: verify.ValidARN,
},
},
},
},
},
},
},
Expand Down Expand Up @@ -159,6 +174,21 @@ func ResourceModel() *schema.Resource {
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.RepositoryAccessMode_Values(), false),
},
"repository_auth_config": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"repository_credentials_provider_arn": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: verify.ValidARN,
},
},
},
},
},
},
},
Expand Down Expand Up @@ -408,8 +438,8 @@ func expandContainer(m map[string]interface{}) *sagemaker.ContainerDefinition {
if v, ok := m["model_data_url"]; ok && v.(string) != "" {
container.ModelDataUrl = aws.String(v.(string))
}
if v, ok := m["environment"]; ok {
container.Environment = flex.ExpandStringMap(v.(map[string]interface{}))
if v, ok := m["environment"].(map[string]interface{}); ok && len(v) > 0 {
container.Environment = flex.ExpandStringMap(v)
}

if v, ok := m["image_config"]; ok {
Expand All @@ -430,9 +460,27 @@ func expandModelImageConfig(l []interface{}) *sagemaker.ImageConfig {
RepositoryAccessMode: aws.String(m["repository_access_mode"].(string)),
}

if v, ok := m["repository_auth_config"].([]interface{}); ok && len(v) > 0 && v[0] != nil {
imageConfig.RepositoryAuthConfig = expandRepositoryAuthConfig(v[0].(map[string]interface{}))
}

return imageConfig
}

func expandRepositoryAuthConfig(tfMap map[string]interface{}) *sagemaker.RepositoryAuthConfig {
if tfMap == nil {
return nil
}

apiObject := &sagemaker.RepositoryAuthConfig{}

if v, ok := tfMap["repository_credentials_provider_arn"].(string); ok && v != "" {
apiObject.RepositoryCredentialsProviderArn = aws.String(v)
}

return apiObject
}

func expandContainers(a []interface{}) []*sagemaker.ContainerDefinition {
containers := make([]*sagemaker.ContainerDefinition, 0, len(a))

Expand Down Expand Up @@ -482,9 +530,27 @@ func flattenImageConfig(imageConfig *sagemaker.ImageConfig) []interface{} {

cfg["repository_access_mode"] = aws.StringValue(imageConfig.RepositoryAccessMode)

if tfMap := flattenRepositoryAuthConfig(imageConfig.RepositoryAuthConfig); len(tfMap) > 0 {
cfg["repository_auth_config"] = []interface{}{tfMap}
}

return []interface{}{cfg}
}

func flattenRepositoryAuthConfig(apiObject *sagemaker.RepositoryAuthConfig) map[string]interface{} {
if apiObject == nil {
return nil
}

tfMap := make(map[string]interface{})

if v := apiObject.RepositoryCredentialsProviderArn; v != nil {
tfMap["repository_credentials_provider_arn"] = aws.StringValue(v)
}

return tfMap
}

func flattenContainers(containers []*sagemaker.ContainerDefinition) []interface{} {
fContainers := make([]interface{}, 0, len(containers))
for _, container := range containers {
Expand Down
133 changes: 107 additions & 26 deletions internal/service/sagemaker/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,34 @@ func TestAccSageMakerModel_vpc(t *testing.T) {
})
}

func TestAccSageMakerModel_primaryContainerPrivateDockerRegistry(t *testing.T) {
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_model.test"

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(t) },
ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID),
ProviderFactories: acctest.ProviderFactories,
CheckDestroy: testAccCheckModelDestroy,
Steps: []resource.TestStep{
{
Config: testAccModelConfig_primaryContainerPrivateDockerRegistry(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckModelExists(resourceName),
resource.TestCheckResourceAttr(resourceName, "primary_container.0.image_config.#", "1"),
resource.TestCheckResourceAttr(resourceName, "primary_container.0.image_config.0.repository_access_mode", "Vpc"),
resource.TestCheckResourceAttr(resourceName, "primary_container.0.image_config.0.repository_auth_config.0.repository_credentials_provider_arn", "arn:aws:lambda:us-east-2:123456789012:function:my-function:1"), //lintignore:AWSAT003,AWSAT005
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
},
})
}

func TestAccSageMakerModel_networkIsolation(t *testing.T) {
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_model.test"
Expand Down Expand Up @@ -438,7 +466,7 @@ data "aws_sagemaker_prebuilt_ecr_image" "test" {
}

func testAccModelConfig_basic(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -447,11 +475,11 @@ resource "aws_sagemaker_model" "test" {
image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
}
}
`, rName)
`, rName))
}

func testAccModelConfig_inferenceExecution(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -468,11 +496,11 @@ resource "aws_sagemaker_model" "test" {
image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
}
}
`, rName)
`, rName))
}

func testAccModelConfig_tags1(rName, tagKey1, tagValue1 string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -485,11 +513,11 @@ resource "aws_sagemaker_model" "test" {
%[2]q = %[3]q
}
}
`, rName, tagKey1, tagValue1)
`, rName, tagKey1, tagValue1))
}

func testAccModelConfig_tags2(rName, tagKey1, tagValue1, tagKey2, tagValue2 string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -503,11 +531,11 @@ resource "aws_sagemaker_model" "test" {
%[4]q = %[5]q
}
}
`, rName, tagKey1, tagValue1, tagKey2, tagValue2)
`, rName, tagKey1, tagValue1, tagKey2, tagValue2))
}

func testAccModelConfig_primaryContainerDataURL(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand Down Expand Up @@ -578,11 +606,11 @@ resource "aws_s3_object" "test" {
key = "model.tar.gz"
content = "some-data"
}
`, rName)
`, rName))
}

func testAccModelConfig_primaryContainerHostname(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -592,11 +620,11 @@ resource "aws_sagemaker_model" "test" {
container_hostname = "test"
}
}
`, rName)
`, rName))
}

func testAccModelConfig_primaryContainerImage(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -609,11 +637,11 @@ resource "aws_sagemaker_model" "test" {
}
}
}
`, rName)
`, rName))
}

func testAccModelConfig_primaryContainerEnvironment(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -626,11 +654,11 @@ resource "aws_sagemaker_model" "test" {
}
}
}
`, rName)
`, rName))
}

func testAccModelConfig_primaryContainerModeSingle(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -640,11 +668,11 @@ resource "aws_sagemaker_model" "test" {
mode = "SingleModel"
}
}
`, rName)
`, rName))
}

func testAccModelConfig_containers(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -657,11 +685,11 @@ resource "aws_sagemaker_model" "test" {
image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
}
}
`, rName)
`, rName))
}

func testAccModelConfig_networkIsolation(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -671,13 +699,11 @@ resource "aws_sagemaker_model" "test" {
image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
}
}
`, rName)
`, rName))
}

func testAccModelConfig_vpcBasic(rName string) string {
return testAccModelConfigBase(rName) +
acctest.ConfigAvailableAZsNoOptIn() +
fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), acctest.ConfigAvailableAZsNoOptIn(), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand Down Expand Up @@ -738,5 +764,60 @@ resource "aws_security_group" "bar" {
Name = %[1]q
}
}
`, rName)
`, rName))
}

//lintignore:AWSAT003,AWSAT005
func testAccModelConfig_primaryContainerPrivateDockerRegistry(rName string) string {
return acctest.ConfigCompose(testAccModelConfigBase(rName), acctest.ConfigAvailableAZsNoOptIn(), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
enable_network_isolation = true
primary_container {
image = "registry.example.com/test-model"
image_config {
repository_access_mode = "Vpc"
repository_auth_config {
repository_credentials_provider_arn = "arn:aws:lambda:us-east-2:123456789012:function:my-function:1"
}
}
}
vpc_config {
subnets = [aws_subnet.test.id]
security_group_ids = [aws_security_group.test.id]
}
}
resource "aws_vpc" "test" {
cidr_block = "10.1.0.0/16"
tags = {
Name = %[1]q
}
}
resource "aws_subnet" "test" {
cidr_block = "10.1.1.0/24"
availability_zone = data.aws_availability_zones.available.names[0]
vpc_id = aws_vpc.test.id
tags = {
Name = %[1]q
}
}
resource "aws_security_group" "test" {
name = "%[1]s-1"
vpc_id = aws_vpc.test.id
tags = {
Name = %[1]q
}
}
`, rName))
}
5 changes: 5 additions & 0 deletions website/docs/r/sagemaker_model.html.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ The `primary_container` and `container` block both support:
### Image Config

* `repository_access_mode` - (Required) Specifies whether the model container is in Amazon ECR or a private Docker registry accessible from your Amazon Virtual Private Cloud (VPC). Allowed values are: `Platform` and `Vpc`.
* `repository_auth_config` - (Optional) Specifies an authentication configuration for the private docker registry where your model image is hosted. Specify a value for this property only if you specified Vpc as the value for the RepositoryAccessMode field, and the private Docker registry where the model image is hosted requires authentication. see [Repository Auth Config](#repository-auth-config).

#### Repository Auth Config

* `repository_credentials_provider_arn` - (Required) The Amazon Resource Name (ARN) of an AWS Lambda function that provides credentials to authenticate to the private Docker registry where your model image is hosted. For information about how to create an AWS Lambda function, see [Create a Lambda function with the console](https://docs.aws.amazon.com/lambda/latest/dg/getting-started-create-function.html) in the _AWS Lambda Developer Guide_.

## Inference Execution Config

Expand Down

0 comments on commit 1558ba3

Please sign in to comment.