diff --git a/.changelog/35479.txt b/.changelog/35479.txt
new file mode 100644
index 000000000000..aa082e3a38d7
--- /dev/null
+++ b/.changelog/35479.txt
@@ -0,0 +1,3 @@
+```release-note:enhancement
+resource/aws_sagemaker_endpoint_configuration: Add `production_variants.managed_instance_scaling` block and `shadow_production_variants.managed_instance_scaling` block
+```
\ No newline at end of file
diff --git a/internal/service/sagemaker/endpoint_configuration.go b/internal/service/sagemaker/endpoint_configuration.go
index 69637f4ef76b..cdc59f649151 100644
--- a/internal/service/sagemaker/endpoint_configuration.go
+++ b/internal/service/sagemaker/endpoint_configuration.go
@@ -377,6 +377,34 @@ func resourceEndpointConfiguration() *schema.Resource {
 						},
 					},
 				},
+				"managed_instance_scaling": {
+					Type:     schema.TypeList,
+					Optional: true,
+					MaxItems: 1,
+					ForceNew: true,
+					Elem: &schema.Resource{
+						Schema: map[string]*schema.Schema{
+							names.AttrStatus: {
+								Type:             schema.TypeString,
+								Optional:         true,
+								ForceNew:         true,
+								ValidateDiagFunc: enum.Validate[awstypes.ManagedInstanceScalingStatus](),
+							},
+							"min_instance_count": {
+								Type:         schema.TypeInt,
+								Optional:     true,
+								ForceNew:     true,
+								ValidateFunc: validation.IntAtLeast(1),
+							},
+							"max_instance_count": {
+								Type:         schema.TypeInt,
+								Optional:     true,
+								ForceNew:     true,
+								ValidateFunc: validation.IntAtLeast(1),
+							},
+						},
+					},
+				},
 				"variant_name": {
 					Type:     schema.TypeString,
 					Optional: true,
@@ -521,6 +549,34 @@ func resourceEndpointConfiguration() *schema.Resource {
 						},
 					},
 				},
+				"managed_instance_scaling": {
+					Type:     schema.TypeList,
+					Optional: true,
+					MaxItems: 1,
+					ForceNew: true,
+					Elem: &schema.Resource{
+						Schema: map[string]*schema.Schema{
+							names.AttrStatus: {
+								Type:             schema.TypeString,
+								Optional:         true,
+								ForceNew:         true,
+								ValidateDiagFunc: enum.Validate[awstypes.ManagedInstanceScalingStatus](),
+							},
+							"min_instance_count": {
+								Type:         schema.TypeInt,
+								Optional:     true,
+								ForceNew:     true,
+								ValidateFunc: validation.IntAtLeast(1),
+							},
+							
"max_instance_count": {
+								Type:         schema.TypeInt,
+								Optional:     true,
+								ForceNew:     true,
+								ValidateFunc: validation.IntAtLeast(1),
+							},
+						},
+					},
+				},
 				"variant_name": {
 					Type:     schema.TypeString,
 					Optional: true,
@@ -737,6 +793,10 @@ func expandProductionVariants(configured []interface{}) []awstypes.ProductionVar
 			l.EnableSSMAccess = aws.Bool(v)
 		}
 
+		if v, ok := data["managed_instance_scaling"].([]interface{}); ok && len(v) > 0 {
+			l.ManagedInstanceScaling = expandManagedInstanceScaling(v)
+		}
+
 		if v, ok := data["inference_ami_version"].(string); ok && v != "" {
 			l.InferenceAmiVersion = awstypes.ProductionVariantInferenceAmiVersion(v)
 		}
@@ -792,6 +852,10 @@ func flattenProductionVariants(list []awstypes.ProductionVariant) []map[string]i
 			l["enable_ssm_access"] = aws.ToBool(i.EnableSSMAccess)
 		}
 
+		if i.ManagedInstanceScaling != nil {
+			l["managed_instance_scaling"] = flattenManagedInstanceScaling(i.ManagedInstanceScaling)
+		}
+
 		result = append(result, l)
 	}
 	return result
@@ -1056,6 +1120,30 @@ func expandCoreDumpConfig(configured []interface{}) *awstypes.ProductionVariantC
 	return c
 }
 
+func expandManagedInstanceScaling(configured []interface{}) *awstypes.ProductionVariantManagedInstanceScaling {
+	if len(configured) == 0 {
+		return nil
+	}
+
+	m := configured[0].(map[string]interface{})
+
+	c := &awstypes.ProductionVariantManagedInstanceScaling{}
+
+	if v, ok := m[names.AttrStatus].(string); ok {
+		c.Status = awstypes.ManagedInstanceScalingStatus(v)
+	}
+
+	if v, ok := m["min_instance_count"].(int); ok && v > 0 {
+		c.MinInstanceCount = aws.Int32(int32(v))
+	}
+
+	if v, ok := m["max_instance_count"].(int); ok && v > 0 {
+		c.MaxInstanceCount = aws.Int32(int32(v))
+	}
+
+	return c
+}
+
 func flattenEndpointConfigAsyncInferenceConfig(config *awstypes.AsyncInferenceConfig) []map[string]interface{} {
 	if config == nil {
 		return []map[string]interface{}{}
@@ -1185,3 +1273,23 @@ func flattenCoreDumpConfig(config *awstypes.ProductionVariantCoreDumpConfig) []m
 	return 
[]map[string]interface{}{cfg}
 }
 
+func flattenManagedInstanceScaling(config *awstypes.ProductionVariantManagedInstanceScaling) []map[string]interface{} {
+	if config == nil {
+		return []map[string]interface{}{}
+	}
+
+	cfg := map[string]interface{}{
+		names.AttrStatus: config.Status,
+	}
+
+	if config.MinInstanceCount != nil {
+		cfg["min_instance_count"] = aws.ToInt32(config.MinInstanceCount)
+	}
+
+	if config.MaxInstanceCount != nil {
+		cfg["max_instance_count"] = aws.ToInt32(config.MaxInstanceCount)
+	}
+
+	return []map[string]interface{}{cfg}
+}
diff --git a/internal/service/sagemaker/endpoint_configuration_test.go b/internal/service/sagemaker/endpoint_configuration_test.go
index 2156bf772119..5a5e5ebbf514 100644
--- a/internal/service/sagemaker/endpoint_configuration_test.go
+++ b/internal/service/sagemaker/endpoint_configuration_test.go
@@ -761,6 +761,41 @@ func TestAccSageMakerEndpointConfiguration_upgradeToEnableSSMAccess(t *testing.T
 	})
 }
 
+func TestAccSageMakerEndpointConfiguration_productionVariantsManagedInstanceScaling(t *testing.T) {
+	ctx := acctest.Context(t)
+	rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
+	resourceName := "aws_sagemaker_endpoint_configuration.test"
+
+	resource.ParallelTest(t, resource.TestCase{
+		PreCheck:                 func() { acctest.PreCheck(ctx, t) },
+		ErrorCheck:               acctest.ErrorCheck(t, names.SageMakerServiceID),
+		ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
+		CheckDestroy:             testAccCheckEndpointConfigurationDestroy(ctx),
+		Steps: []resource.TestStep{
+			{
+				Config: testAccEndpointConfigurationConfig_productionVariantsManagedInstanceScaling(rName),
+				Check: resource.ComposeAggregateTestCheckFunc(
+					testAccCheckEndpointConfigurationExists(ctx, resourceName),
+					resource.TestCheckResourceAttr(resourceName, names.AttrName, rName),
+					resource.TestCheckResourceAttr(resourceName, "production_variants.#", acctest.Ct1),
+					resource.TestCheckResourceAttr(resourceName, "production_variants.0.variant_name", "variant-1"),
+					
resource.TestCheckResourceAttr(resourceName, "production_variants.0.model_name", rName),
+					resource.TestCheckResourceAttr(resourceName, "production_variants.0.initial_instance_count", acctest.Ct1),
+					resource.TestCheckResourceAttr(resourceName, "production_variants.0.instance_type", "ml.g5.4xlarge"),
+					resource.TestCheckResourceAttr(resourceName, "production_variants.0.managed_instance_scaling.0.status", "ENABLED"),
+					resource.TestCheckResourceAttr(resourceName, "production_variants.0.managed_instance_scaling.0.min_instance_count", acctest.Ct1),
+					resource.TestCheckResourceAttr(resourceName, "production_variants.0.managed_instance_scaling.0.max_instance_count", acctest.Ct2),
+				),
+			},
+			{
+				ResourceName:      resourceName,
+				ImportState:       true,
+				ImportStateVerify: true,
+			},
+		},
+	})
+}
+
 func testAccCheckEndpointConfigurationDestroy(ctx context.Context) resource.TestCheckFunc {
 	return func(s *terraform.State) error {
 		conn := acctest.Provider.Meta().(*conns.AWSClient).SageMakerClient(ctx)
@@ -1395,3 +1430,131 @@ resource "aws_sagemaker_endpoint_configuration" "test" {
 }
 `, rName))
 }
+
+func testAccEndpointConfigurationConfig_productionVariantsManagedInstanceScaling(rName string) string {
+	return acctest.ConfigCompose(fmt.Sprintf(`
+data "aws_region" "current" {}
+data "aws_partition" "current" {}
+data "aws_sagemaker_prebuilt_ecr_image" "managed_instance_scaling_test" {
+  repository_name = "djl-inference"
+  image_tag       = "0.27.0-deepspeed0.12.6-cu121"
+}
+
+data "aws_iam_policy_document" "managed_instance_scaling_test_policy" {
+  statement {
+    effect = "Allow"
+
+    actions = [
+      "cloudwatch:PutMetricData",
+      "logs:CreateLogStream",
+      "logs:PutLogEvents",
+      "logs:CreateLogGroup",
+      "logs:DescribeLogStreams",
+      "ecr:GetAuthorizationToken",
+      "ecr:BatchCheckLayerAvailability",
+      "ecr:GetDownloadUrlForLayer",
+      "ecr:BatchGetImage",
+    ]
+
+    resources = [
+      "*",
+    ]
+  }
+
+  statement {
+    effect = "Allow"
+
+    actions = [
+      "s3:GetObject",
+      "s3:ListBucket",
+    ]
+
+    resources = 
[
+      "${aws_s3_bucket.managed_instance_scaling_test.arn}",
+      "${aws_s3_bucket.managed_instance_scaling_test.arn}/*",
+    ]
+  }
+}
+
+resource "aws_iam_policy" "managed_instance_scaling_test" {
+  name        = %[1]q
+  description = "Allow SageMaker to create model"
+  policy      = data.aws_iam_policy_document.managed_instance_scaling_test_policy.json
+}
+
+resource "aws_iam_role" "managed_instance_scaling_test" {
+  name               = %[1]q
+  path               = "/"
+  assume_role_policy = data.aws_iam_policy_document.assume_role.json
+}
+
+data "aws_iam_policy_document" "assume_role" {
+  statement {
+    actions = ["sts:AssumeRole"]
+
+    principals {
+      type        = "Service"
+      identifiers = ["sagemaker.amazonaws.com"]
+    }
+  }
+}
+
+resource "aws_iam_role_policy_attachment" "managed_instance_scaling_test" {
+  role       = aws_iam_role.managed_instance_scaling_test.name
+  policy_arn = aws_iam_policy.managed_instance_scaling_test.arn
+}
+
+resource "aws_s3_bucket" "managed_instance_scaling_test" {
+  bucket        = %[1]q
+  force_destroy = true
+}
+
+resource "aws_s3_object" "managed_instance_scaling_test" {
+  bucket  = aws_s3_bucket.managed_instance_scaling_test.bucket
+  key     = "model/inference.py"
+  content = "some-data"
+}
+
+resource "aws_sagemaker_model" "managed_instance_scaling_test" {
+  name               = %[1]q
+  execution_role_arn = aws_iam_role.managed_instance_scaling_test.arn
+  primary_container {
+    image = data.aws_sagemaker_prebuilt_ecr_image.managed_instance_scaling_test.registry_path
+    model_data_source {
+      s3_data_source {
+        s3_data_type     = "S3Prefix"
+        s3_uri           = "s3://${aws_s3_object.managed_instance_scaling_test.bucket}/model/"
+        compression_type = "None"
+      }
+    }
+  }
+  depends_on = [
+    aws_iam_role_policy_attachment.managed_instance_scaling_test
+  ]
+}
+
+resource "aws_sagemaker_endpoint_configuration" "test" {
+  name = %[1]q
+
+  production_variants {
+    variant_name           = "variant-1"
+    model_name             = aws_sagemaker_model.managed_instance_scaling_test.name
+    initial_instance_count = 1
+    instance_type          = "ml.g5.4xlarge"
+
+    
managed_instance_scaling {
+      status             = "ENABLED"
+      min_instance_count = 1
+      max_instance_count = 2
+    }
+
+    routing_config {
+      routing_strategy = "LEAST_OUTSTANDING_REQUESTS"
+    }
+
+    model_data_download_timeout_in_seconds            = 60
+    container_startup_health_check_timeout_in_seconds = 60
+  }
+}
+`, rName))
+}
diff --git a/website/docs/r/sagemaker_endpoint_configuration.html.markdown b/website/docs/r/sagemaker_endpoint_configuration.html.markdown
index 0cf7337ccfab..981d460fdfda 100644
--- a/website/docs/r/sagemaker_endpoint_configuration.html.markdown
+++ b/website/docs/r/sagemaker_endpoint_configuration.html.markdown
@@ -58,6 +58,7 @@ This resource supports the following arguments:
 * `model_name` - (Required) The name of the model to use.
 * `routing_config` - (Optional) Sets how the endpoint routes incoming traffic. See [routing_config](#routing_config) below.
 * `serverless_config` - (Optional) Specifies configuration for how an endpoint performs asynchronous inference.
+* `managed_instance_scaling` - (Optional) Settings that control the range in the number of instances that the endpoint provisions as it scales up or down to accommodate traffic.
 * `variant_name` - (Optional) The name of the variant. If omitted, Terraform will assign a random, unique name.
 * `volume_size_in_gb` - (Optional) The size, in GB, of the ML storage volume attached to individual inference instance associated with the production variant. Valid values between `1` and `512`.
 
@@ -76,6 +77,12 @@ This resource supports the following arguments:
 * `memory_size_in_mb` - (Required) The memory size of your serverless endpoint. Valid values are in 1 GB increments: `1024` MB, `2048` MB, `3072` MB, `4096` MB, `5120` MB, or `6144` MB.
 * `provisioned_concurrency` - The amount of provisioned concurrency to allocate for the serverless endpoint. Should be less than or equal to `max_concurrency`. Valid values are between `1` and `200`.
 
+#### managed_instance_scaling
+
+* `status` - (Optional) Indicates whether managed instance scaling is enabled. Valid values are `ENABLED` and `DISABLED`.
+* `min_instance_count` - (Optional) The minimum number of instances that the endpoint must retain when it scales down to accommodate a decrease in traffic.
+* `max_instance_count` - (Optional) The maximum number of instances that the endpoint can provision when it scales up to accommodate an increase in traffic.
+
 ### data_capture_config
 
 * `initial_sampling_percentage` - (Required) Portion of data to capture. Should be between 0 and 100.