Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

s/sagemaker Add support repository_auth_config in image_config #25557

Merged
merged 6 commits into from
Jun 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changelog/25557.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/sagemaker: Add `repository_auth_config` arguments in support of [Private Docker Registry](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-containers-inference-private.html)
```
70 changes: 68 additions & 2 deletions internal/service/sagemaker/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ func ResourceModel() *schema.Resource {
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.RepositoryAccessMode_Values(), false),
},
"repository_auth_config": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"repository_credentials_provider_arn": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: verify.ValidARN,
},
},
},
},
},
},
},
Expand Down Expand Up @@ -159,6 +174,21 @@ func ResourceModel() *schema.Resource {
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.RepositoryAccessMode_Values(), false),
},
"repository_auth_config": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"repository_credentials_provider_arn": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: verify.ValidARN,
},
},
},
},
},
},
},
Expand Down Expand Up @@ -408,8 +438,8 @@ func expandContainer(m map[string]interface{}) *sagemaker.ContainerDefinition {
if v, ok := m["model_data_url"]; ok && v.(string) != "" {
container.ModelDataUrl = aws.String(v.(string))
}
if v, ok := m["environment"]; ok {
container.Environment = flex.ExpandStringMap(v.(map[string]interface{}))
if v, ok := m["environment"].(map[string]interface{}); ok && len(v) > 0 {
container.Environment = flex.ExpandStringMap(v)
}

if v, ok := m["image_config"]; ok {
Expand All @@ -430,9 +460,27 @@ func expandModelImageConfig(l []interface{}) *sagemaker.ImageConfig {
RepositoryAccessMode: aws.String(m["repository_access_mode"].(string)),
}

if v, ok := m["repository_auth_config"].([]interface{}); ok && len(v) > 0 && v[0] != nil {
imageConfig.RepositoryAuthConfig = expandRepositoryAuthConfig(v[0].(map[string]interface{}))
}

return imageConfig
}

func expandRepositoryAuthConfig(tfMap map[string]interface{}) *sagemaker.RepositoryAuthConfig {
if tfMap == nil {
return nil
}

apiObject := &sagemaker.RepositoryAuthConfig{}

if v, ok := tfMap["repository_credentials_provider_arn"].(string); ok && v != "" {
apiObject.RepositoryCredentialsProviderArn = aws.String(v)
}

return apiObject
}

func expandContainers(a []interface{}) []*sagemaker.ContainerDefinition {
containers := make([]*sagemaker.ContainerDefinition, 0, len(a))

Expand Down Expand Up @@ -482,9 +530,27 @@ func flattenImageConfig(imageConfig *sagemaker.ImageConfig) []interface{} {

cfg["repository_access_mode"] = aws.StringValue(imageConfig.RepositoryAccessMode)

if tfMap := flattenRepositoryAuthConfig(imageConfig.RepositoryAuthConfig); len(tfMap) > 0 {
cfg["repository_auth_config"] = []interface{}{tfMap}
}

return []interface{}{cfg}
}

func flattenRepositoryAuthConfig(apiObject *sagemaker.RepositoryAuthConfig) map[string]interface{} {
if apiObject == nil {
return nil
}

tfMap := make(map[string]interface{})

if v := apiObject.RepositoryCredentialsProviderArn; v != nil {
tfMap["repository_credentials_provider_arn"] = aws.StringValue(v)
}

return tfMap
}

func flattenContainers(containers []*sagemaker.ContainerDefinition) []interface{} {
fContainers := make([]interface{}, 0, len(containers))
for _, container := range containers {
Expand Down
133 changes: 107 additions & 26 deletions internal/service/sagemaker/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,34 @@ func TestAccSageMakerModel_vpc(t *testing.T) {
})
}

func TestAccSageMakerModel_primaryContainerPrivateDockerRegistry(t *testing.T) {
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_model.test"

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(t) },
ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID),
ProviderFactories: acctest.ProviderFactories,
CheckDestroy: testAccCheckModelDestroy,
Steps: []resource.TestStep{
{
Config: testAccModelConfig_primaryContainerPrivateDockerRegistry(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckModelExists(resourceName),
resource.TestCheckResourceAttr(resourceName, "primary_container.0.image_config.#", "1"),
resource.TestCheckResourceAttr(resourceName, "primary_container.0.image_config.0.repository_access_mode", "Vpc"),
resource.TestCheckResourceAttr(resourceName, "primary_container.0.image_config.0.repository_auth_config.0.repository_credentials_provider_arn", "arn:aws:lambda:us-east-2:123456789012:function:my-function:1"), //lintignore:AWSAT003,AWSAT005
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
},
})
}

func TestAccSageMakerModel_networkIsolation(t *testing.T) {
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_model.test"
Expand Down Expand Up @@ -438,7 +466,7 @@ data "aws_sagemaker_prebuilt_ecr_image" "test" {
}

func testAccModelConfig_basic(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -447,11 +475,11 @@ resource "aws_sagemaker_model" "test" {
image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
}
}
`, rName)
`, rName))
}

func testAccModelConfig_inferenceExecution(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -468,11 +496,11 @@ resource "aws_sagemaker_model" "test" {
image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
}
}
`, rName)
`, rName))
}

func testAccModelConfig_tags1(rName, tagKey1, tagValue1 string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -485,11 +513,11 @@ resource "aws_sagemaker_model" "test" {
%[2]q = %[3]q
}
}
`, rName, tagKey1, tagValue1)
`, rName, tagKey1, tagValue1))
}

func testAccModelConfig_tags2(rName, tagKey1, tagValue1, tagKey2, tagValue2 string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -503,11 +531,11 @@ resource "aws_sagemaker_model" "test" {
%[4]q = %[5]q
}
}
`, rName, tagKey1, tagValue1, tagKey2, tagValue2)
`, rName, tagKey1, tagValue1, tagKey2, tagValue2))
}

func testAccModelConfig_primaryContainerDataURL(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand Down Expand Up @@ -578,11 +606,11 @@ resource "aws_s3_object" "test" {
key = "model.tar.gz"
content = "some-data"
}
`, rName)
`, rName))
}

func testAccModelConfig_primaryContainerHostname(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -592,11 +620,11 @@ resource "aws_sagemaker_model" "test" {
container_hostname = "test"
}
}
`, rName)
`, rName))
}

func testAccModelConfig_primaryContainerImage(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -609,11 +637,11 @@ resource "aws_sagemaker_model" "test" {
}
}
}
`, rName)
`, rName))
}

func testAccModelConfig_primaryContainerEnvironment(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -626,11 +654,11 @@ resource "aws_sagemaker_model" "test" {
}
}
}
`, rName)
`, rName))
}

func testAccModelConfig_primaryContainerModeSingle(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -640,11 +668,11 @@ resource "aws_sagemaker_model" "test" {
mode = "SingleModel"
}
}
`, rName)
`, rName))
}

func testAccModelConfig_containers(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -657,11 +685,11 @@ resource "aws_sagemaker_model" "test" {
image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
}
}
`, rName)
`, rName))
}

func testAccModelConfig_networkIsolation(rName string) string {
return testAccModelConfigBase(rName) + fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand All @@ -671,13 +699,11 @@ resource "aws_sagemaker_model" "test" {
image = data.aws_sagemaker_prebuilt_ecr_image.test.registry_path
}
}
`, rName)
`, rName))
}

func testAccModelConfig_vpcBasic(rName string) string {
return testAccModelConfigBase(rName) +
acctest.ConfigAvailableAZsNoOptIn() +
fmt.Sprintf(`
return acctest.ConfigCompose(testAccModelConfigBase(rName), acctest.ConfigAvailableAZsNoOptIn(), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
Expand Down Expand Up @@ -738,5 +764,60 @@ resource "aws_security_group" "bar" {
Name = %[1]q
}
}
`, rName)
`, rName))
}

//lintignore:AWSAT003,AWSAT005
func testAccModelConfig_primaryContainerPrivateDockerRegistry(rName string) string {
return acctest.ConfigCompose(testAccModelConfigBase(rName), acctest.ConfigAvailableAZsNoOptIn(), fmt.Sprintf(`
resource "aws_sagemaker_model" "test" {
name = %[1]q
execution_role_arn = aws_iam_role.test.arn
enable_network_isolation = true

primary_container {
image = "registry.example.com/test-model"

image_config {
repository_access_mode = "Vpc"

repository_auth_config {
repository_credentials_provider_arn = "arn:aws:lambda:us-east-2:123456789012:function:my-function:1"
}
}
}

vpc_config {
subnets = [aws_subnet.test.id]
security_group_ids = [aws_security_group.test.id]
}
}

resource "aws_vpc" "test" {
cidr_block = "10.1.0.0/16"

tags = {
Name = %[1]q
}
}

resource "aws_subnet" "test" {
cidr_block = "10.1.1.0/24"
availability_zone = data.aws_availability_zones.available.names[0]
vpc_id = aws_vpc.test.id

tags = {
Name = %[1]q
}
}

resource "aws_security_group" "test" {
name = "%[1]s-1"
vpc_id = aws_vpc.test.id

tags = {
Name = %[1]q
}
}
`, rName))
}
5 changes: 5 additions & 0 deletions website/docs/r/sagemaker_model.html.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ The `primary_container` and `container` block both support:
### Image Config

* `repository_access_mode` - (Required) Specifies whether the model container is in Amazon ECR or a private Docker registry accessible from your Amazon Virtual Private Cloud (VPC). Allowed values are: `Platform` and `Vpc`.
* `repository_auth_config` - (Optional) Specifies an authentication configuration for the private docker registry where your model image is hosted. Specify a value for this property only if you specified Vpc as the value for the RepositoryAccessMode field, and the private Docker registry where the model image is hosted requires authentication. see [Repository Auth Config](#repository-auth-config).

#### Repository Auth Config

* `repository_credentials_provider_arn` - (Required) The Amazon Resource Name (ARN) of an AWS Lambda function that provides credentials to authenticate to the private Docker registry where your model image is hosted. For information about how to create an AWS Lambda function, see [Create a Lambda function with the console](https://docs.aws.amazon.com/lambda/latest/dg/getting-started-create-function.html) in the _AWS Lambda Developer Guide_.

## Inference Execution Config

Expand Down