Skip to content

Commit

Permalink
Merge pull request #28159 from DrFaust92/sagemaker-end-shadow
Browse files Browse the repository at this point in the history
r/sagemaker_endpoint_configurtion - add `shadow_production_variants` and other prod variants args
  • Loading branch information
ewbankkit authored Dec 12, 2022
2 parents efbeea0 + d313b21 commit 4e986fe
Show file tree
Hide file tree
Showing 4 changed files with 346 additions and 34 deletions.
3 changes: 3 additions & 0 deletions .changelog/28159.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_sagemaker_endpoint_configuration: Add `shadow_production_variants`, `production_variants.container_startup_health_check_timeout_in_seconds`, `production_variants.core_dump_config`, `production_variants.model_data_download_timeout_in_seconds`, and `production_variants.volume_size_in_gb` arguments
```
233 changes: 233 additions & 0 deletions internal/service/sagemaker/endpoint_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,37 @@ func ResourceEndpointConfiguration() *schema.Resource {
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ProductionVariantAcceleratorType_Values(), false),
},
"container_startup_health_check_timeout_in_seconds": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntBetween(60, 3600),
},
"core_dump_config": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"destination_s3_uri": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.All(
validation.StringMatch(regexp.MustCompile(`^(https|s3)://([^/])/?(.*)$`), ""),
validation.StringLenBetween(1, 512),
),
},
"kms_key_id": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: verify.ValidARN,
},
},
},
},
"initial_instance_count": {
Type: schema.TypeInt,
Optional: true,
Expand All @@ -246,6 +277,12 @@ func ResourceEndpointConfiguration() *schema.Resource {
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ProductionVariantInstanceType_Values(), false),
},
"model_data_download_timeout_in_seconds": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntBetween(60, 3600),
},
"model_name": {
Type: schema.TypeString,
Required: true,
Expand Down Expand Up @@ -279,6 +316,124 @@ func ResourceEndpointConfiguration() *schema.Resource {
Computed: true,
ForceNew: true,
},
"volume_size_in_gb": {
Type: schema.TypeInt,
Optional: true,
Computed: true,
ForceNew: true,
ValidateFunc: validation.IntBetween(1, 512),
},
},
},
},
"shadow_production_variants": {
Type: schema.TypeList,
Optional: true,
MinItems: 1,
MaxItems: 10,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"accelerator_type": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ProductionVariantAcceleratorType_Values(), false),
},
"container_startup_health_check_timeout_in_seconds": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntBetween(60, 3600),
},
"core_dump_config": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"destination_s3_uri": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: validation.All(
validation.StringMatch(regexp.MustCompile(`^(https|s3)://([^/])/?(.*)$`), ""),
validation.StringLenBetween(1, 512),
),
},
"kms_key_id": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
ValidateFunc: verify.ValidARN,
},
},
},
},
"initial_instance_count": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntAtLeast(1),
},
"initial_variant_weight": {
Type: schema.TypeFloat,
Optional: true,
ForceNew: true,
ValidateFunc: validation.FloatAtLeast(0),
Default: 1,
},
"instance_type": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validation.StringInSlice(sagemaker.ProductionVariantInstanceType_Values(), false),
},
"model_data_download_timeout_in_seconds": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntBetween(60, 3600),
},
"model_name": {
Type: schema.TypeString,
Required: true,
ForceNew: true,
},
"serverless_config": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
ForceNew: true,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"max_concurrency": {
Type: schema.TypeInt,
Required: true,
ForceNew: true,
ValidateFunc: validation.IntBetween(1, 200),
},
"memory_size_in_mb": {
Type: schema.TypeInt,
Required: true,
ForceNew: true,
ValidateFunc: validation.IntInSlice([]int{1024, 2048, 3072, 4096, 5120, 6144}),
},
},
},
},
"variant_name": {
Type: schema.TypeString,
Optional: true,
Computed: true,
ForceNew: true,
},
"volume_size_in_gb": {
Type: schema.TypeInt,
Optional: true,
ForceNew: true,
ValidateFunc: validation.IntBetween(1, 512),
},
},
},
},
Expand Down Expand Up @@ -315,6 +470,10 @@ func resourceEndpointConfigurationCreate(d *schema.ResourceData, meta interface{
createOpts.Tags = Tags(tags.IgnoreAWS())
}

if v, ok := d.GetOk("shadow_production_variants"); ok && len(v.([]interface{})) > 0 {
createOpts.ShadowProductionVariants = expandProductionVariants(v.([]interface{}))
}

if v, ok := d.GetOk("data_capture_config"); ok {
createOpts.DataCaptureConfig = expandDataCaptureConfig(v.([]interface{}))
}
Expand Down Expand Up @@ -358,6 +517,10 @@ func resourceEndpointConfigurationRead(d *schema.ResourceData, meta interface{})
return fmt.Errorf("setting production_variants for SageMaker Endpoint Configuration (%s): %w", d.Id(), err)
}

if err := d.Set("shadow_production_variants", flattenProductionVariants(endpointConfig.ShadowProductionVariants)); err != nil {
return fmt.Errorf("setting shadow_production_variants for SageMaker Endpoint Configuration (%s): %w", d.Id(), err)
}

if err := d.Set("data_capture_config", flattenDataCaptureConfig(endpointConfig.DataCaptureConfig)); err != nil {
return fmt.Errorf("setting data_capture_config for SageMaker Endpoint Configuration (%s): %w", d.Id(), err)
}
Expand Down Expand Up @@ -433,6 +596,18 @@ func expandProductionVariants(configured []interface{}) []*sagemaker.ProductionV
l.InitialInstanceCount = aws.Int64(int64(v))
}

if v, ok := data["container_startup_health_check_timeout_in_seconds"].(int); ok && v > 0 {
l.ContainerStartupHealthCheckTimeoutInSeconds = aws.Int64(int64(v))
}

if v, ok := data["model_data_download_timeout_in_seconds"].(int); ok && v > 0 {
l.ModelDataDownloadTimeoutInSeconds = aws.Int64(int64(v))
}

if v, ok := data["volume_size_in_gb"].(int); ok && v > 0 {
l.VolumeSizeInGB = aws.Int64(int64(v))
}

if v, ok := data["instance_type"].(string); ok && v != "" {
l.InstanceType = aws.String(v)
}
Expand All @@ -455,6 +630,10 @@ func expandProductionVariants(configured []interface{}) []*sagemaker.ProductionV
l.ServerlessConfig = expandServerlessConfig(v)
}

if v, ok := data["core_dump_config"].([]interface{}); ok && len(v) > 0 {
l.CoreDumpConfig = expandCoreDumpConfig(v)
}

containers = append(containers, l)
}

Expand All @@ -476,6 +655,18 @@ func flattenProductionVariants(list []*sagemaker.ProductionVariant) []map[string
l["initial_instance_count"] = aws.Int64Value(i.InitialInstanceCount)
}

if i.ContainerStartupHealthCheckTimeoutInSeconds != nil {
l["container_startup_health_check_timeout_in_seconds"] = aws.Int64Value(i.ContainerStartupHealthCheckTimeoutInSeconds)
}

if i.ModelDataDownloadTimeoutInSeconds != nil {
l["model_data_download_timeout_in_seconds"] = aws.Int64Value(i.ModelDataDownloadTimeoutInSeconds)
}

if i.VolumeSizeInGB != nil {
l["volume_size_in_gb"] = aws.Int64Value(i.VolumeSizeInGB)
}

if i.InstanceType != nil {
l["instance_type"] = aws.StringValue(i.InstanceType)
}
Expand All @@ -484,6 +675,10 @@ func flattenProductionVariants(list []*sagemaker.ProductionVariant) []map[string
l["serverless_config"] = flattenServerlessConfig(i.ServerlessConfig)
}

if i.CoreDumpConfig != nil {
l["core_dump_config"] = flattenCoreDumpConfig(i.CoreDumpConfig)
}

result = append(result, l)
}
return result
Expand Down Expand Up @@ -700,6 +895,26 @@ func expandServerlessConfig(configured []interface{}) *sagemaker.ProductionVaria
return c
}

func expandCoreDumpConfig(configured []interface{}) *sagemaker.ProductionVariantCoreDumpConfig {
if len(configured) == 0 {
return nil
}

m := configured[0].(map[string]interface{})

c := &sagemaker.ProductionVariantCoreDumpConfig{}

if v, ok := m["destination_s3_uri"].(string); ok {
c.DestinationS3Uri = aws.String(v)
}

if v, ok := m["kms_key_id"].(string); ok {
c.KmsKeyId = aws.String(v)
}

return c
}

func flattenEndpointConfigAsyncInferenceConfig(config *sagemaker.AsyncInferenceConfig) []map[string]interface{} {
if config == nil {
return []map[string]interface{}{}
Expand Down Expand Up @@ -787,3 +1002,21 @@ func flattenServerlessConfig(config *sagemaker.ProductionVariantServerlessConfig

return []map[string]interface{}{cfg}
}

func flattenCoreDumpConfig(config *sagemaker.ProductionVariantCoreDumpConfig) []map[string]interface{} {
if config == nil {
return []map[string]interface{}{}
}

cfg := map[string]interface{}{}

if config.DestinationS3Uri != nil {
cfg["destination_s3_uri"] = aws.StringValue(config.DestinationS3Uri)
}

if config.KmsKeyId != nil {
cfg["kms_key_id"] = aws.StringValue(config.KmsKeyId)
}

return []map[string]interface{}{cfg}
}
Loading

0 comments on commit 4e986fe

Please sign in to comment.