Skip to content

Commit

Permalink
Address 2nd round of feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
jckuester committed Mar 25, 2019
1 parent 077f20d commit 6ecd49f
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 325 deletions.
92 changes: 33 additions & 59 deletions aws/resource_aws_sagemaker_endpoint_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ package aws
import (
"fmt"
"log"
"time"

"github.com/hashicorp/terraform/helper/validation"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/sagemaker"
"github.com/hashicorp/terraform/helper/resource"
"github.com/hashicorp/terraform/helper/schema"
Expand Down Expand Up @@ -74,6 +72,7 @@ func resourceAwsSagemakerEndpointConfiguration() *schema.Resource {
Optional: true,
ForceNew: true,
ValidateFunc: FloatAtLeast(0),
Default: 1,
},

"accelerator_type": {
Expand All @@ -85,10 +84,11 @@ func resourceAwsSagemakerEndpointConfiguration() *schema.Resource {
},
},

"kms_key_id": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
"kms_key_arn": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validateArn,
},

"tags": tagsSchema(),
Expand All @@ -108,26 +108,21 @@ func resourceAwsSagemakerEndpointConfigurationCreate(d *schema.ResourceData, met

createOpts := &sagemaker.CreateEndpointConfigInput{
EndpointConfigName: aws.String(name),
ProductionVariants: expandSagemakerProductionVariants(d.Get("production_variants").([]interface{})),
}

prodVariants, err := expandProductionVariants(d.Get("production_variants").([]interface{}))
if err != nil {
return err
}
createOpts.ProductionVariants = prodVariants

if v, ok := d.GetOk("kms_key_id"); ok {
if v, ok := d.GetOk("kms_key_arn"); ok {
createOpts.KmsKeyId = aws.String(v.(string))
}

if v, ok := d.GetOk("tags"); ok {
createOpts.Tags = tagsFromMapSagemaker(v.(map[string]interface{}))
}

log.Printf("[DEBUG] Sagemaker endpoint configuration create config: %#v", *createOpts)
_, err = conn.CreateEndpointConfig(createOpts)
log.Printf("[DEBUG] SageMaker Endpoint Configuration create config: %#v", *createOpts)
_, err := conn.CreateEndpointConfig(createOpts)
if err != nil {
return fmt.Errorf("error creating Sagemaker endpoint configuration: %s", err)
return fmt.Errorf("error creating SageMaker Endpoint Configuration: %s", err)
}
d.SetId(name)

Expand All @@ -143,12 +138,12 @@ func resourceAwsSagemakerEndpointConfigurationRead(d *schema.ResourceData, meta

endpointConfig, err := conn.DescribeEndpointConfig(request)
if err != nil {
if sagemakerErr, ok := err.(awserr.Error); ok && sagemakerErr.Code() == "ValidationException" {
log.Printf("[INFO] unable to find the sagemaker endpoint configuration resource and therefore it is removed from the state: %s", d.Id())
if isAWSErr(err, "ValidationException", "") {
log.Printf("[INFO] unable to find the SageMaker Endpoint Configuration resource and therefore it is removed from the state: %s", d.Id())
d.SetId("")
return nil
}
return fmt.Errorf("error reading Sagemaker endpoint configuration %s: %s", d.Id(), err)
return fmt.Errorf("error reading SageMaker Endpoint Configuration %s: %s", d.Id(), err)
}

if err := d.Set("arn", endpointConfig.EndpointConfigArn); err != nil {
Expand All @@ -160,15 +155,15 @@ func resourceAwsSagemakerEndpointConfigurationRead(d *schema.ResourceData, meta
if err := d.Set("production_variants", flattenProductionVariants(endpointConfig.ProductionVariants)); err != nil {
return err
}
if err := d.Set("kms_key_id", endpointConfig.KmsKeyId); err != nil {
if err := d.Set("kms_key_arn", endpointConfig.KmsKeyId); err != nil {
return err
}

tagsOutput, err := conn.ListTags(&sagemaker.ListTagsInput{
ResourceArn: endpointConfig.EndpointConfigArn,
})
if err != nil {
return fmt.Errorf("error listing tags of Sagemaker endpoint configuration %s: %s", d.Id(), err)
return fmt.Errorf("error listing tags of SageMaker Endpoint Configuration %s: %s", d.Id(), err)
}
if err := d.Set("tags", tagsToMapSagemaker(tagsOutput.Tags)); err != nil {
return err
Expand All @@ -179,16 +174,9 @@ func resourceAwsSagemakerEndpointConfigurationRead(d *schema.ResourceData, meta
func resourceAwsSagemakerEndpointConfigurationUpdate(d *schema.ResourceData, meta interface{}) error {
conn := meta.(*AWSClient).sagemakerconn

d.Partial(true)

if err := setSagemakerTags(conn, d); err != nil {
return err
} else {
d.SetPartial("tags")
}

d.Partial(false)

return resourceAwsSagemakerEndpointConfigurationRead(d, meta)
}

Expand All @@ -198,28 +186,22 @@ func resourceAwsSagemakerEndpointConfigurationDelete(d *schema.ResourceData, met
deleteOpts := &sagemaker.DeleteEndpointConfigInput{
EndpointConfigName: aws.String(d.Id()),
}
log.Printf("[INFO] Deleting Sagemaker endpoint configuration: %s", d.Id())
log.Printf("[INFO] Deleting SageMaker Endpoint Configuration: %s", d.Id())

return resource.Retry(5*time.Minute, func() *resource.RetryError {
_, err := conn.DeleteEndpointConfig(deleteOpts)
if err == nil {
return nil
}
_, err := conn.DeleteEndpointConfig(deleteOpts)

sagemakerErr, ok := err.(awserr.Error)
if !ok {
return resource.NonRetryableError(err)
}
if isAWSErr(err, sagemaker.ErrCodeResourceNotFound, "") {
return nil
}

if sagemakerErr.Code() == "ResourceNotFound" {
return resource.RetryableError(err)
}
if err != nil {
return fmt.Errorf("error deleting SageMaker Endpoint Configuration (%s): %s", d.Id(), err)
}

return resource.NonRetryableError(fmt.Errorf("Error deleting Sagemaker endpoint configuration: %s", err))
})
return nil
}

func expandProductionVariants(configured []interface{}) ([]*sagemaker.ProductionVariant, error) {
func expandSagemakerProductionVariants(configured []interface{}) []*sagemaker.ProductionVariant {
containers := make([]*sagemaker.ProductionVariant, 0, len(configured))

for _, lRaw := range configured {
Expand All @@ -239,8 +221,6 @@ func expandProductionVariants(configured []interface{}) ([]*sagemaker.Production

if v, ok := data["initial_variant_weight"]; ok {
l.InitialVariantWeight = aws.Float64(v.(float64))
} else {
l.InitialVariantWeight = aws.Float64(1)
}

if v, ok := data["accelerator_type"]; ok && v.(string) != "" {
Expand All @@ -250,26 +230,20 @@ func expandProductionVariants(configured []interface{}) ([]*sagemaker.Production
containers = append(containers, l)
}

return containers, nil
return containers
}

func flattenProductionVariants(list []*sagemaker.ProductionVariant) []map[string]interface{} {
result := make([]map[string]interface{}, 0, len(list))

for _, i := range list {
l := map[string]interface{}{
"instance_type": *i.InstanceType,
"model_name": *i.ModelName,
"initial_instance_count": *i.InitialInstanceCount,
}
if i.VariantName != nil {
l["variant_name"] = *i.VariantName
}
if i.InitialVariantWeight != nil {
l["initial_variant_weight"] = *i.InitialVariantWeight
}
if i.AcceleratorType != nil {
l["accelerator_type"] = *i.AcceleratorType
"accelerator_type": aws.StringValue(i.AcceleratorType),
"initial_instance_count": aws.Int64Value(i.InitialInstanceCount),
"initial_variant_weight": aws.Float64Value(i.InitialVariantWeight),
"instance_type": aws.StringValue(i.InstanceType),
"model_name": aws.StringValue(i.ModelName),
"variant_name": aws.StringValue(i.VariantName),
}

result = append(result, l)
Expand Down
Loading

0 comments on commit 6ecd49f

Please sign in to comment.