From f789e1ad899dafddb5c544b76bc905def4945e3b Mon Sep 17 00:00:00 2001 From: drfaust92 Date: Tue, 7 Jun 2022 22:15:50 +0300 Subject: [PATCH 1/3] add serverless support --- .../sagemaker/endpoint_configuration.go | 220 ++++++++++++------ .../sagemaker/endpoint_configuration_test.go | 49 +++- ...maker_endpoint_configuration.html.markdown | 28 ++- 3 files changed, 211 insertions(+), 86 deletions(-) diff --git a/internal/service/sagemaker/endpoint_configuration.go b/internal/service/sagemaker/endpoint_configuration.go index 985f688c7e4..27365321719 100644 --- a/internal/service/sagemaker/endpoint_configuration.go +++ b/internal/service/sagemaker/endpoint_configuration.go @@ -106,75 +106,6 @@ func ResourceEndpointConfiguration() *schema.Resource { }, }, }, - - "name": { - Type: schema.TypeString, - Optional: true, - Computed: true, - ForceNew: true, - ValidateFunc: validName, - }, - - "production_variants": { - Type: schema.TypeList, - Required: true, - Elem: &schema.Resource{ - Schema: map[string]*schema.Schema{ - "variant_name": { - Type: schema.TypeString, - Optional: true, - Computed: true, - ForceNew: true, - }, - - "model_name": { - Type: schema.TypeString, - Required: true, - ForceNew: true, - }, - - "initial_instance_count": { - Type: schema.TypeInt, - Required: true, - ForceNew: true, - ValidateFunc: validation.IntAtLeast(1), - }, - - "instance_type": { - Type: schema.TypeString, - Required: true, - ForceNew: true, - ValidateFunc: validation.StringInSlice(sagemaker.ProductionVariantInstanceType_Values(), false), - }, - - "initial_variant_weight": { - Type: schema.TypeFloat, - Optional: true, - ForceNew: true, - ValidateFunc: validation.FloatAtLeast(0), - Default: 1, - }, - - "accelerator_type": { - Type: schema.TypeString, - Optional: true, - ForceNew: true, - ValidateFunc: validation.StringInSlice(sagemaker.ProductionVariantAcceleratorType_Values(), false), - }, - }, - }, - }, - - "kms_key_arn": { - Type: schema.TypeString, - Optional: true, - ForceNew: true, 
- ValidateFunc: verify.ValidARN, - }, - - "tags": tftags.TagsSchema(), - "tags_all": tftags.TagsSchemaComputed(), - "data_capture_config": { Type: schema.TypeList, MaxItems: 1, @@ -270,6 +201,89 @@ func ResourceEndpointConfiguration() *schema.Resource { }, }, }, + "name": { + Type: schema.TypeString, + Optional: true, + Computed: true, + ForceNew: true, + ValidateFunc: validName, + }, + "kms_key_arn": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + ValidateFunc: verify.ValidARN, + }, + "production_variants": { + Type: schema.TypeList, + Required: true, + MinItems: 1, + MaxItems: 10, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "accelerator_type": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + ValidateFunc: validation.StringInSlice(sagemaker.ProductionVariantAcceleratorType_Values(), false), + }, + "initial_instance_count": { + Type: schema.TypeInt, + Optional: true, + ForceNew: true, + ValidateFunc: validation.IntAtLeast(1), + }, + "initial_variant_weight": { + Type: schema.TypeFloat, + Optional: true, + ForceNew: true, + ValidateFunc: validation.FloatAtLeast(0), + Default: 1, + }, + "instance_type": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + ValidateFunc: validation.StringInSlice(sagemaker.ProductionVariantInstanceType_Values(), false), + }, + "model_name": { + Type: schema.TypeString, + Required: true, + ForceNew: true, + }, + "serverless_config": { + Type: schema.TypeList, + Optional: true, + MaxItems: 1, + ForceNew: true, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "max_concurrency": { + Type: schema.TypeInt, + Required: true, + ForceNew: true, + ValidateFunc: validation.IntBetween(1, 200), + }, + "memory_size_in_mb": { + Type: schema.TypeInt, + Required: true, + ForceNew: true, + ValidateFunc: validation.IntInSlice([]int{1024, 2048, 3072, 4096, 5120, 6144}), + }, + }, + }, + }, + "variant_name": { + Type: schema.TypeString, + Optional: true, + Computed: 
true, + ForceNew: true, + }, + }, + }, + }, + "tags": tftags.TagsSchema(), + "tags_all": tftags.TagsSchemaComputed(), }, CustomizeDiff: verify.SetTagsDiff, @@ -412,9 +426,15 @@ func expandProductionVariants(configured []interface{}) []*sagemaker.ProductionV data := lRaw.(map[string]interface{}) l := &sagemaker.ProductionVariant{ - InstanceType: aws.String(data["instance_type"].(string)), - ModelName: aws.String(data["model_name"].(string)), - InitialInstanceCount: aws.Int64(int64(data["initial_instance_count"].(int))), + ModelName: aws.String(data["model_name"].(string)), + } + + if v, ok := data["initial_instance_count"].(int); ok && v > 0 { + l.InitialInstanceCount = aws.Int64(int64(v)) + } + + if v, ok := data["instance_type"].(string); ok && v != "" { + l.InstanceType = aws.String(v) } if v, ok := data["variant_name"]; ok { @@ -431,6 +451,10 @@ func expandProductionVariants(configured []interface{}) []*sagemaker.ProductionV l.AcceleratorType = aws.String(v) } + if v, ok := data["serverless_config"].([]interface{}); ok && len(v) > 0 { + l.ServerlessConfig = expandServerlessConfig(v) + } + containers = append(containers, l) } @@ -443,13 +467,23 @@ func flattenProductionVariants(list []*sagemaker.ProductionVariant) []map[string for _, i := range list { l := map[string]interface{}{ "accelerator_type": aws.StringValue(i.AcceleratorType), - "initial_instance_count": aws.Int64Value(i.InitialInstanceCount), "initial_variant_weight": aws.Float64Value(i.InitialVariantWeight), - "instance_type": aws.StringValue(i.InstanceType), "model_name": aws.StringValue(i.ModelName), "variant_name": aws.StringValue(i.VariantName), } + if i.InitialInstanceCount != nil { + l["initial_instance_count"] = aws.Int64Value(i.InitialInstanceCount) + } + + if i.InstanceType != nil { + l["instance_type"] = aws.StringValue(i.InstanceType) + } + + if i.ServerlessConfig != nil { + l["serverless_config"] = flattenServerlessConfig(i.ServerlessConfig) + } + result = append(result, l) } return result 
@@ -646,6 +680,26 @@ func expandEndpointConfigNotificationConfig(configured []interface{}) *sagemaker return c } +func expandServerlessConfig(configured []interface{}) *sagemaker.ProductionVariantServerlessConfig { + if len(configured) == 0 { + return nil + } + + m := configured[0].(map[string]interface{}) + + c := &sagemaker.ProductionVariantServerlessConfig{} + + if v, ok := m["max_concurrency"].(int); ok { + c.MaxConcurrency = aws.Int64(int64(v)) + } + + if v, ok := m["memory_size_in_mb"].(int); ok { + c.MemorySizeInMB = aws.Int64(int64(v)) + } + + return c +} + func flattenEndpointConfigAsyncInferenceConfig(config *sagemaker.AsyncInferenceConfig) []map[string]interface{} { if config == nil { return []map[string]interface{}{} @@ -715,3 +769,21 @@ func flattenEndpointConfigNotificationConfig(config *sagemaker.AsyncInferenceNot return []map[string]interface{}{cfg} } + +func flattenServerlessConfig(config *sagemaker.ProductionVariantServerlessConfig) []map[string]interface{} { + if config == nil { + return []map[string]interface{}{} + } + + cfg := map[string]interface{}{} + + if config.MaxConcurrency != nil { + cfg["max_concurrency"] = aws.Int64Value(config.MaxConcurrency) + } + + if config.MemorySizeInMB != nil { + cfg["memory_size_in_mb"] = aws.Int64Value(config.MemorySizeInMB) + } + + return []map[string]interface{}{cfg} +} diff --git a/internal/service/sagemaker/endpoint_configuration_test.go b/internal/service/sagemaker/endpoint_configuration_test.go index 42282252ff9..2b338388fcb 100644 --- a/internal/service/sagemaker/endpoint_configuration_test.go +++ b/internal/service/sagemaker/endpoint_configuration_test.go @@ -35,7 +35,7 @@ func TestAccSageMakerEndpointConfiguration_basic(t *testing.T) { resource.TestCheckResourceAttr(resourceName, "production_variants.0.initial_instance_count", "2"), resource.TestCheckResourceAttr(resourceName, "production_variants.0.instance_type", "ml.t2.medium"), resource.TestCheckResourceAttr(resourceName, 
"production_variants.0.initial_variant_weight", "1"), - resource.TestCheckResourceAttr(resourceName, "production_variants.0.code_dump_config.#", "0"), + resource.TestCheckResourceAttr(resourceName, "production_variants.0.serverless_config.#", "0"), resource.TestCheckResourceAttr(resourceName, "data_capture_config.#", "0"), resource.TestCheckResourceAttr(resourceName, "async_inference_config.#", "0"), ), @@ -49,6 +49,35 @@ func TestAccSageMakerEndpointConfiguration_basic(t *testing.T) { }) } +func TestAccSageMakerEndpointConfiguration_ProductionVariants_serverless(t *testing.T) { + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + resourceName := "aws_sagemaker_endpoint_configuration.test" + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID), + ProviderFactories: acctest.ProviderFactories, + CheckDestroy: testAccCheckEndpointConfigurationDestroy, + Steps: []resource.TestStep{ + { + Config: testAccEndpointConfigurationConfig_serverless(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckEndpointConfigurationExists(resourceName), + resource.TestCheckResourceAttr(resourceName, "production_variants.#", "1"), + resource.TestCheckResourceAttr(resourceName, "production_variants.0.serverless_config.#", "1"), + resource.TestCheckResourceAttr(resourceName, "production_variants.0.serverless_config.0.max_concurrency", "1"), + resource.TestCheckResourceAttr(resourceName, "production_variants.0.serverless_config.0.memory_size_in_mb", "1024"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + func TestAccSageMakerEndpointConfiguration_ProductionVariants_initialVariantWeight(t *testing.T) { rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) resourceName := "aws_sagemaker_endpoint_configuration.test" @@ -753,3 +782,21 @@ resource "aws_sagemaker_endpoint_configuration" "test" { } `, rName) } + 
+func testAccEndpointConfigurationConfig_serverless(rName string) string { + return testAccEndpointConfigurationConfig_Base(rName) + fmt.Sprintf(` +resource "aws_sagemaker_endpoint_configuration" "test" { + name = %q + + production_variants { + variant_name = "variant-1" + model_name = aws_sagemaker_model.test.name + + serverless_config { + max_concurrency = 1 + memory_size_in_mb = 1024 + } + } +} +`, rName) +} diff --git a/website/docs/r/sagemaker_endpoint_configuration.html.markdown b/website/docs/r/sagemaker_endpoint_configuration.html.markdown index 5467c4a0ea5..e47df8c9cd2 100644 --- a/website/docs/r/sagemaker_endpoint_configuration.html.markdown +++ b/website/docs/r/sagemaker_endpoint_configuration.html.markdown @@ -43,16 +43,22 @@ The following arguments are supported: * `data_capture_config` - (Optional) Specifies the parameters to capture input/output of SageMaker models endpoints. Fields are documented below. * `async_inference_config` - (Optional) Specifies configuration for how an endpoint performs asynchronous inference. -The `production_variants` block supports: +### production_variants -* `initial_instance_count` - (Required) Initial number of instances used for auto-scaling. -* `instance_type` (Required) - The type of instance to start. +* `initial_instance_count` - (Optional) Initial number of instances used for auto-scaling. +* `instance_type` (Optional) - The type of instance to start. * `accelerator_type` (Optional) - The size of the Elastic Inference (EI) instance to use for the production variant. -* `initial_variant_weight` (Optional) - Determines initial traffic distribution among all of the models that you specify in the endpoint configuration. If unspecified, it defaults to 1.0. +* `initial_variant_weight` (Optional) - Determines initial traffic distribution among all of the models that you specify in the endpoint configuration. If unspecified, it defaults to `1.0`. * `model_name` - (Required) The name of the model to use. 
* `variant_name` - (Optional) The name of the variant. If omitted, Terraform will assign a random, unique name. +* `serverless_config` - (Optional) Specifies configuration for a serverless endpoint. Fields are documented below. -The `data_capture_config` block supports: +#### serverless_config + +* `max_concurrency` - (Required) The maximum number of concurrent invocations your serverless endpoint can process. Valid values are between `1` and `200`. +* `memory_size_in_mb` - (Required) The memory size of your serverless endpoint. Valid values are in 1 GB increments: `1024` MB, `2048` MB, `3072` MB, `4096` MB, `5120` MB, or `6144` MB. + +### data_capture_config * `initial_sampling_percentage` - (Required) Portion of data to capture. Should be between 0 and 100. * `destination_s3_uri` - (Required) The URL for S3 location where the captured data is stored. @@ -61,31 +67,31 @@ The `data_capture_config` block supports: * `enable_capture` - (Optional) Flag to enable data capture. Defaults to `false`. * `capture_content_type_header` - (Optional) The content type headers to capture. Fields are documented below. -The `capture_options` block supports: +#### capture_options * `capture_mode` - (Required) Specifies the data to be captured. Should be one of `Input` or `Output`. -The `capture_content_type_header` block supports: +#### capture_content_type_header * `csv_content_types` - (Optional) The CSV content type headers to capture. * `json_content_types` - (Optional) The JSON content type headers to capture. -The `async_inference_config` block supports: +### async_inference_config * `output_config` - (Required) Specifies the configuration for asynchronous inference invocation outputs. * `client_config` - (Optional) Configures the behavior of the client used by Amazon SageMaker to interact with the model container during asynchronous inference. 
-The `client_config` block supports: +#### client_config * `max_concurrent_invocations_per_instance` - (Optional) The maximum number of concurrent requests sent by the SageMaker client to the model container. If no value is provided, Amazon SageMaker will choose an optimal value for you. -The `output_config` block supports: +#### output_config * `s3_output_path` - (Required) The Amazon S3 location to upload inference responses to. * `kms_key_id` - (Optional) The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt the asynchronous inference output in Amazon S3. * `notification_config` - (Optional) Specifies the configuration for notifications of inference results for asynchronous inference. -The `notification_config` block supports: +##### notification_config * `error_topic` - (Optional) Amazon SNS topic to post a notification to when inference fails. If no topic is provided, no notification is sent on failure. * `success_topic` - (Optional) Amazon SNS topic to post a notification to when inference completes successfully. If no topic is provided, no notification is sent on success. 
From 51d1a77ae774f5f33577af9c327364cc71bb4a20 Mon Sep 17 00:00:00 2001 From: drfaust92 Date: Tue, 7 Jun 2022 22:19:08 +0300 Subject: [PATCH 2/3] changelog --- .changelog/25218.txt | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .changelog/25218.txt diff --git a/.changelog/25218.txt b/.changelog/25218.txt new file mode 100644 index 00000000000..91a9cd5a215 --- /dev/null +++ b/.changelog/25218.txt @@ -0,0 +1,7 @@ +```release-note:enhancement +resource/aws_sagemaker_endpoint_configuration: Add `serverless_config` argument +``` + +```release-note:enhancement +resource/aws_sagemaker_endpoint_configuration: Make `production_variants.initial_instance_count` and `production_variants.instance_type` arguments optional +``` \ No newline at end of file From 986707d5949fb00e187b8ca7214e308255915010 Mon Sep 17 00:00:00 2001 From: drfaust92 Date: Tue, 7 Jun 2022 22:21:45 +0300 Subject: [PATCH 3/3] fmt --- internal/service/sagemaker/endpoint_configuration_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/service/sagemaker/endpoint_configuration_test.go b/internal/service/sagemaker/endpoint_configuration_test.go index 2b338388fcb..b0cf9c16af2 100644 --- a/internal/service/sagemaker/endpoint_configuration_test.go +++ b/internal/service/sagemaker/endpoint_configuration_test.go @@ -789,13 +789,13 @@ resource "aws_sagemaker_endpoint_configuration" "test" { name = %q production_variants { - variant_name = "variant-1" - model_name = aws_sagemaker_model.test.name + variant_name = "variant-1" + model_name = aws_sagemaker_model.test.name serverless_config { max_concurrency = 1 memory_size_in_mb = 1024 - } + } } } `, rName)