diff --git a/aws/config.go b/aws/config.go index 81e0b5f5a69..8e1c088e7a8 100644 --- a/aws/config.go +++ b/aws/config.go @@ -92,6 +92,7 @@ import ( "github.com/aws/aws-sdk-go/service/route53" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3control" + "github.com/aws/aws-sdk-go/service/sagemaker" "github.com/aws/aws-sdk-go/service/secretsmanager" "github.com/aws/aws-sdk-go/service/securityhub" "github.com/aws/aws-sdk-go/service/servicecatalog" @@ -203,6 +204,7 @@ type AWSClient struct { apigateway *apigateway.APIGateway appautoscalingconn *applicationautoscaling.ApplicationAutoScaling autoscalingconn *autoscaling.AutoScaling + sagemakerconn *sagemaker.SageMaker s3conn *s3.S3 s3controlconn *s3control.S3Control secretsmanagerconn *secretsmanager.SecretsManager @@ -579,6 +581,7 @@ func (c *Config) Client() (interface{}, error) { client.rdsconn = rds.New(awsRdsSess) client.redshiftconn = redshift.New(sess) client.resourcegroupsconn = resourcegroups.New(sess) + client.sagemakerconn = sagemaker.New(sess) client.simpledbconn = simpledb.New(sess) client.s3conn = s3.New(awsS3Sess) client.s3controlconn = s3control.New(awsS3ControlSess) diff --git a/aws/provider.go b/aws/provider.go index 9fb4cd455ab..752986c2f08 100644 --- a/aws/provider.go +++ b/aws/provider.go @@ -630,6 +630,7 @@ func Provider() terraform.ResourceProvider { "aws_s3_bucket_notification": resourceAwsS3BucketNotification(), "aws_s3_bucket_metric": resourceAwsS3BucketMetric(), "aws_s3_bucket_inventory": resourceAwsS3BucketInventory(), + "aws_sagemaker_notebook_instance": resourceAwsSagemakerNotebookInstance(), "aws_security_group": resourceAwsSecurityGroup(), "aws_network_interface_sg_attachment": resourceAwsNetworkInterfaceSGAttachment(), "aws_default_security_group": resourceAwsDefaultSecurityGroup(), diff --git a/aws/resource_aws_sagemaker_notebook_instance.go b/aws/resource_aws_sagemaker_notebook_instance.go new file mode 100644 index 00000000000..948b01c30e5 --- /dev/null +++ b/aws/resource_aws_sagemaker_notebook_instance.go @@ -0,0 +1,386 @@ +package aws + +import ( + "fmt" + "log" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/sagemaker" + "github.com/hashicorp/terraform/helper/resource" + "github.com/hashicorp/terraform/helper/schema" +) + +func resourceAwsSagemakerNotebookInstance() *schema.Resource { + return &schema.Resource{ + Create: resourceAwsSagemakerNotebookInstanceCreate, + Read: resourceAwsSagemakerNotebookInstanceRead, + Update: resourceAwsSagemakerNotebookInstanceUpdate, + Delete: resourceAwsSagemakerNotebookInstanceDelete, + Importer: &schema.ResourceImporter{ + State: schema.ImportStatePassthrough, + }, + + Schema: map[string]*schema.Schema{ + "arn": { + Type: schema.TypeString, + Computed: true, + }, + + "name": { + Type: schema.TypeString, + Required: true, + ForceNew: true, + ValidateFunc: validateSagemakerName, + }, + + "role_arn": { + Type: schema.TypeString, + Required: true, + }, + + "instance_type": { + Type: schema.TypeString, + Required: true, + }, + + "subnet_id": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + }, + + "security_groups": { + Type: schema.TypeSet, + MinItems: 1, + Optional: true, + Computed: true, + ForceNew: true, + Elem: &schema.Schema{Type: schema.TypeString}, + Set: schema.HashString, + }, + + "kms_key_id": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + }, + + "tags": tagsSchema(), + }, + } +} + +func resourceAwsSagemakerNotebookInstanceCreate(d *schema.ResourceData, meta interface{}) error { + conn := meta.(*AWSClient).sagemakerconn + + name := d.Get("name").(string) + + createOpts := &sagemaker.CreateNotebookInstanceInput{ + SecurityGroupIds: expandStringSet(d.Get("security_groups").(*schema.Set)), + NotebookInstanceName: aws.String(name), + RoleArn: aws.String(d.Get("role_arn").(string)), + InstanceType: aws.String(d.Get("instance_type").(string)), + } + + if s, ok := d.GetOk("subnet_id"); ok { + createOpts.SubnetId = aws.String(s.(string)) + } + + if k, ok := d.GetOk("kms_key_id"); ok { + createOpts.KmsKeyId = aws.String(k.(string)) + } + + if v, ok := d.GetOk("tags"); ok { + tagsIn := v.(map[string]interface{}) + createOpts.Tags = tagsFromMapSagemaker(tagsIn) + } + + log.Printf("[DEBUG] sagemaker notebook instance create config: %#v", *createOpts) + _, err := conn.CreateNotebookInstance(createOpts) + if err != nil { + return fmt.Errorf("Error creating Sagemaker Notebook Instance: %s", err) + } + + d.SetId(name) + log.Printf("[INFO] sagemaker notebook instance ID: %s", d.Id()) + + stateConf := &resource.StateChangeConf{ + Pending: []string{ + sagemaker.NotebookInstanceStatusUpdating, + sagemaker.NotebookInstanceStatusPending, + sagemaker.NotebookInstanceStatusStopped, + }, + Target: []string{sagemaker.NotebookInstanceStatusInService}, + Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id()), + Timeout: 10 * time.Minute, + } + _, err = stateConf.WaitForState() + if err != nil { + return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to create: %s", d.Id(), err) + } + + return resourceAwsSagemakerNotebookInstanceRead(d, meta) +} + +func resourceAwsSagemakerNotebookInstanceRead(d *schema.ResourceData, meta interface{}) error { + conn := meta.(*AWSClient).sagemakerconn + + describeNotebookInput := &sagemaker.DescribeNotebookInstanceInput{ + NotebookInstanceName: aws.String(d.Id()), + } + notebookInstance, err := conn.DescribeNotebookInstance(describeNotebookInput) + if err != nil { + if isAWSErr(err, "ValidationException", "RecordNotFound") { + d.SetId("") + log.Printf("[WARN] Unable to find sageMaker notebook instance (%s); removing from state", d.Id()) + return nil + } + return fmt.Errorf("error finding sagemaker notebook instance (%s): %s", d.Id(), err) + + } + + if err := d.Set("security_groups", flattenStringList(notebookInstance.SecurityGroups)); err != nil { + return fmt.Errorf("error setting security groups for sagemaker notebook instance (%s): %s", d.Id(), err) + } + if err := d.Set("name", notebookInstance.NotebookInstanceName); err != nil { + return fmt.Errorf("error setting name for sagemaker notebook instance (%s): %s", d.Id(), err) + } + if err := d.Set("role_arn", notebookInstance.RoleArn); err != nil { + return fmt.Errorf("error setting role_arn for sagemaker notebook instance (%s): %s", d.Id(), err) + } + if err := d.Set("instance_type", notebookInstance.InstanceType); err != nil { + return fmt.Errorf("error setting instance_type for sagemaker notebook instance (%s): %s", d.Id(), err) + } + if err := d.Set("subnet_id", notebookInstance.SubnetId); err != nil { + return fmt.Errorf("error setting subnet_id for sagemaker notebook instance (%s): %s", d.Id(), err) + } + + if err := d.Set("kms_key_id", notebookInstance.KmsKeyId); err != nil { + return fmt.Errorf("error setting kms_key_id for sagemaker notebook instance (%s): %s", d.Id(), err) + } + + if err := d.Set("arn", notebookInstance.NotebookInstanceArn); err != nil { + return fmt.Errorf("error setting arn for sagemaker notebook instance (%s): %s", d.Id(), err) + } + tagsOutput, err := conn.ListTags(&sagemaker.ListTagsInput{ + ResourceArn: notebookInstance.NotebookInstanceArn, + }) + if err != nil { + return fmt.Errorf("error listing tags for sagemaker notebook instance (%s): %s", d.Id(), err) + } + + if err := d.Set("tags", tagsToMapSagemaker(tagsOutput.Tags)); err != nil { + return fmt.Errorf("error setting tags for notebook instance (%s): %s", d.Id(), err) + } + return nil +} + +func resourceAwsSagemakerNotebookInstanceUpdate(d *schema.ResourceData, meta interface{}) error { + conn := meta.(*AWSClient).sagemakerconn + + d.Partial(true) + + if err := setSagemakerTags(conn, d); err != nil { + return err + } + d.SetPartial("tags") + + hasChanged := false + // Update + updateOpts := &sagemaker.UpdateNotebookInstanceInput{ + NotebookInstanceName: aws.String(d.Get("name").(string)), + } + + if d.HasChange("role_arn") { + updateOpts.RoleArn = aws.String(d.Get("role_arn").(string)) + hasChanged = true + } + + if d.HasChange("instance_type") { + updateOpts.InstanceType = aws.String(d.Get("instance_type").(string)) + hasChanged = true + } + + if hasChanged { + + // Stop notebook + _, previousStatus, _ := sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id())() + if previousStatus != sagemaker.NotebookInstanceStatusStopped { + if err := stopSagemakerNotebookInstance(conn, d.Id()); err != nil { + return fmt.Errorf("error stopping sagemaker notebook instance prior to updating: %s", err) + } + } + + if _, err := conn.UpdateNotebookInstance(updateOpts); err != nil { + return fmt.Errorf("error updating sagemaker notebook instance: %s", err) + } + + stateConf := &resource.StateChangeConf{ + Pending: []string{ + sagemaker.NotebookInstanceStatusUpdating, + }, + Target: []string{sagemaker.NotebookInstanceStatusStopped}, + Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id()), + Timeout: 10 * time.Minute, + } + _, err := stateConf.WaitForState() + if err != nil { + return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to update: %s", d.Id(), err) + } + + // Restart if needed + if previousStatus == sagemaker.NotebookInstanceStatusInService { + startOpts := &sagemaker.StartNotebookInstanceInput{ + NotebookInstanceName: aws.String(d.Id()), + } + + // StartNotebookInstance sometimes doesn't take so we'll check for a state change and if + // it doesn't change we'll send another request + err := resource.Retry(5*time.Minute, func() *resource.RetryError { + if _, err := conn.StartNotebookInstance(startOpts); err != nil { + return resource.NonRetryableError(fmt.Errorf("error starting sagemaker notebook instance (%s): %s", d.Id(), err)) + } + stateConf := &resource.StateChangeConf{ + Pending: []string{ + sagemaker.NotebookInstanceStatusStopped, + }, + Target: []string{sagemaker.NotebookInstanceStatusInService, sagemaker.NotebookInstanceStatusPending}, + Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id()), + Timeout: 30 * time.Second, + } + _, err := stateConf.WaitForState() + if err != nil { + return resource.RetryableError(fmt.Errorf("error waiting for sagemaker notebook instance (%s) to start: %s", d.Id(), err)) + } + + return nil + }) + if err != nil { + return err + } + + stateConf := &resource.StateChangeConf{ + Pending: []string{ + sagemaker.NotebookInstanceStatusUpdating, + sagemaker.NotebookInstanceStatusPending, + sagemaker.NotebookInstanceStatusStopped, + }, + Target: []string{sagemaker.NotebookInstanceStatusInService}, + Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id()), + Timeout: 10 * time.Minute, + } + _, err = stateConf.WaitForState() + if err != nil { + return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to start after update: %s", d.Id(), err) + } + } + } + + d.Partial(false) + + return resourceAwsSagemakerNotebookInstanceRead(d, meta) +} + +func resourceAwsSagemakerNotebookInstanceDelete(d *schema.ResourceData, meta interface{}) error { + conn := meta.(*AWSClient).sagemakerconn + + describeNotebookInput := &sagemaker.DescribeNotebookInstanceInput{ + NotebookInstanceName: aws.String(d.Id()), + } + notebook, err := conn.DescribeNotebookInstance(describeNotebookInput) + if err != nil { + if isAWSErr(err, "ValidationException", "RecordNotFound") { + return nil + } + return fmt.Errorf("unable to find sagemaker notebook instance to delete (%s): %s", d.Id(), err) + } + if *notebook.NotebookInstanceStatus != sagemaker.NotebookInstanceStatusFailed && *notebook.NotebookInstanceStatus != sagemaker.NotebookInstanceStatusStopped { + if err := stopSagemakerNotebookInstance(conn, d.Id()); err != nil { + return err + } + } + + deleteOpts := &sagemaker.DeleteNotebookInstanceInput{ + NotebookInstanceName: aws.String(d.Id()), + } + + if _, err := conn.DeleteNotebookInstance(deleteOpts); err != nil { + return fmt.Errorf("error trying to delete sagemaker notebook instance (%s): %s", d.Id(), err) + } + + stateConf := &resource.StateChangeConf{ + Pending: []string{ + sagemaker.NotebookInstanceStatusDeleting, + }, + Target: []string{""}, + Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id()), + Timeout: 10 * time.Minute, + } + _, err = stateConf.WaitForState() + if err != nil { + return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to delete: %s", d.Id(), err) + } + + return nil +} + +func stopSagemakerNotebookInstance(conn *sagemaker.SageMaker, id string) error { + describeNotebookInput := &sagemaker.DescribeNotebookInstanceInput{ + NotebookInstanceName: aws.String(id), + } + notebook, err := conn.DescribeNotebookInstance(describeNotebookInput) + if err != nil { + if isAWSErr(err, "ValidationException", "RecordNotFound") { + return nil + } + return fmt.Errorf("unable to find sagemaker notebook instance (%s): %s", id, err) + } + if *notebook.NotebookInstanceStatus == sagemaker.NotebookInstanceStatusStopped { + return nil + } + + stopOpts := &sagemaker.StopNotebookInstanceInput{ + NotebookInstanceName: aws.String(id), + } + + if _, err := conn.StopNotebookInstance(stopOpts); err != nil { + return fmt.Errorf("Error stopping sagemaker notebook instance: %s", err) + } + + stateConf := &resource.StateChangeConf{ + Pending: []string{ + sagemaker.NotebookInstanceStatusStopping, + }, + Target: []string{sagemaker.NotebookInstanceStatusStopped}, + Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, id), + Timeout: 10 * time.Minute, + } + _, err = stateConf.WaitForState() + if err != nil { + return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to stop: %s", id, err) + } + + return nil +} + +func sagemakerNotebookInstanceStateRefreshFunc(conn *sagemaker.SageMaker, name string) resource.StateRefreshFunc { + return func() (interface{}, string, error) { + describeNotebookInput := &sagemaker.DescribeNotebookInstanceInput{ + NotebookInstanceName: aws.String(name), + } + notebook, err := conn.DescribeNotebookInstance(describeNotebookInput) + if err != nil { + if isAWSErr(err, "ValidationException", "RecordNotFound") { + return 1, "", nil + } + return nil, "", err + } + + if notebook == nil { + return nil, "", nil + } + + return notebook, *notebook.NotebookInstanceStatus, nil + } +} diff --git a/aws/resource_aws_sagemaker_notebook_instance_test.go b/aws/resource_aws_sagemaker_notebook_instance_test.go new file mode 100644 index 00000000000..d4a1f0d27ca --- /dev/null +++ b/aws/resource_aws_sagemaker_notebook_instance_test.go @@ -0,0 +1,373 @@ +package aws + +import ( + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/sagemaker" + "github.com/hashicorp/terraform/helper/resource" + "github.com/hashicorp/terraform/terraform" +) + +const sagemakerTestAccSagemakerNotebookInstanceResourceNamePrefix = "terraform-testacc-" + +func TestAccAWSSagemakerNotebookInstance_basic(t *testing.T) { + var notebook sagemaker.DescribeNotebookInstanceOutput + notebookName := resource.PrefixedUniqueId(sagemakerTestAccSagemakerNotebookInstanceResourceNamePrefix) + var resourceName = "aws_sagemaker_notebook_instance.foo" + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, + CheckDestroy: testAccCheckAWSSagemakerNotebookInstanceDestroy, + Steps: []resource.TestStep{ + { + Config: testAccAWSSagemakerNotebookInstanceConfig(notebookName), + Check: resource.ComposeTestCheckFunc( + testAccCheckAWSSagemakerNotebookInstanceExists(resourceName, ¬ebook), + testAccCheckAWSSagemakerNotebookInstanceName(¬ebook, notebookName), + + resource.TestCheckResourceAttr( + "aws_sagemaker_notebook_instance.foo", "name", notebookName), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccAWSSagemakerNotebookInstance_update(t *testing.T) { + var notebook sagemaker.DescribeNotebookInstanceOutput + notebookName := resource.PrefixedUniqueId(sagemakerTestAccSagemakerNotebookInstanceResourceNamePrefix) + var resourceName = "aws_sagemaker_notebook_instance.foo" + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, + CheckDestroy: testAccCheckAWSSagemakerNotebookInstanceDestroy, + Steps: []resource.TestStep{ + { + Config: testAccAWSSagemakerNotebookInstanceConfig(notebookName), + Check: resource.ComposeTestCheckFunc( + testAccCheckAWSSagemakerNotebookInstanceExists(resourceName, ¬ebook), + + resource.TestCheckResourceAttr( + "aws_sagemaker_notebook_instance.foo", "instance_type", "ml.t2.medium"), + ), + }, + + { + Config: testAccAWSSagemakerNotebookInstanceUpdateConfig(notebookName), + Check: resource.ComposeTestCheckFunc( + testAccCheckAWSSagemakerNotebookInstanceExists("aws_sagemaker_notebook_instance.foo", ¬ebook), + + resource.TestCheckResourceAttr( + "aws_sagemaker_notebook_instance.foo", "instance_type", "ml.m4.xlarge"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccAWSSagemakerNotebookInstance_tags(t *testing.T) { + var notebook sagemaker.DescribeNotebookInstanceOutput + notebookName := resource.PrefixedUniqueId(sagemakerTestAccSagemakerNotebookInstanceResourceNamePrefix) + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, + CheckDestroy: testAccCheckAWSSagemakerNotebookInstanceDestroy, + Steps: []resource.TestStep{ + { + Config: testAccAWSSagemakerNotebookInstanceTagsConfig(notebookName), + Check: resource.ComposeTestCheckFunc( + testAccCheckAWSSagemakerNotebookInstanceExists("aws_sagemaker_notebook_instance.foo", ¬ebook), + testAccCheckAWSSagemakerNotebookInstanceTags(¬ebook, "foo", "bar"), + + resource.TestCheckResourceAttr( + "aws_sagemaker_notebook_instance.foo", "name", notebookName), + resource.TestCheckResourceAttr("aws_sagemaker_notebook_instance.foo", "tags.%", "1"), + resource.TestCheckResourceAttr("aws_sagemaker_notebook_instance.foo", "tags.foo", "bar"), + ), + }, + + { + Config: testAccAWSSagemakerNotebookInstanceTagsUpdateConfig(notebookName), + Check: resource.ComposeTestCheckFunc( + testAccCheckAWSSagemakerNotebookInstanceExists("aws_sagemaker_notebook_instance.foo", ¬ebook), + testAccCheckAWSSagemakerNotebookInstanceTags(¬ebook, "foo", ""), + testAccCheckAWSSagemakerNotebookInstanceTags(¬ebook, "bar", "baz"), + + resource.TestCheckResourceAttr("aws_sagemaker_notebook_instance.foo", "tags.%", "1"), + resource.TestCheckResourceAttr("aws_sagemaker_notebook_instance.foo", "tags.bar", "baz"), + ), + }, + }, + }) +} + +func TestAccAWSSagemakerNotebookInstance_disappears(t *testing.T) { + var notebook sagemaker.DescribeNotebookInstanceOutput + notebookName := resource.PrefixedUniqueId(sagemakerTestAccSagemakerNotebookInstanceResourceNamePrefix) + var resourceName = "aws_sagemaker_notebook_instance.foo" + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, + CheckDestroy: testAccCheckAWSSagemakerNotebookInstanceDestroy, + Steps: []resource.TestStep{ + { + Config: testAccAWSSagemakerNotebookInstanceConfig(notebookName), + Check: resource.ComposeTestCheckFunc( + testAccCheckAWSSagemakerNotebookInstanceExists(resourceName, ¬ebook), + testAccCheckAWSSagemakerNotebookInstanceDisappears(¬ebook), + ), + ExpectNonEmptyPlan: true, + }, + }, + }) +} + +func testAccCheckAWSSagemakerNotebookInstanceDestroy(s *terraform.State) error { + conn := testAccProvider.Meta().(*AWSClient).sagemakerconn + + for _, rs := range s.RootModule().Resources { + if rs.Type != "aws_sagemaker_notebook_instance" { + continue + } + + describeNotebookInput := &sagemaker.DescribeNotebookInstanceInput{ + NotebookInstanceName: aws.String(rs.Primary.ID), + } + notebookInstance, err := conn.DescribeNotebookInstance(describeNotebookInput) + if err != nil { + return nil + } + + if *notebookInstance.NotebookInstanceName == rs.Primary.ID { + return fmt.Errorf("sagemaker notebook instance %q still exists", rs.Primary.ID) + } + } + + return nil +} + +func testAccCheckAWSSagemakerNotebookInstanceExists(n string, notebook *sagemaker.DescribeNotebookInstanceOutput) resource.TestCheckFunc { + return func(s *terraform.State) error { + rs, ok := s.RootModule().Resources[n] + if !ok { + return fmt.Errorf("Not found: %s", n) + } + + if rs.Primary.ID == "" { + return fmt.Errorf("No sagmaker Notebook Instance ID is set") + } + + conn := testAccProvider.Meta().(*AWSClient).sagemakerconn + opts := &sagemaker.DescribeNotebookInstanceInput{ + NotebookInstanceName: aws.String(rs.Primary.ID), + } + resp, err := conn.DescribeNotebookInstance(opts) + if err != nil { + return err + } + + *notebook = *resp + + return nil + } +} + +func testAccCheckAWSSagemakerNotebookInstanceDisappears(instance *sagemaker.DescribeNotebookInstanceOutput) resource.TestCheckFunc { + return func(s *terraform.State) error { + conn := testAccProvider.Meta().(*AWSClient).sagemakerconn + + if *instance.NotebookInstanceStatus != sagemaker.NotebookInstanceStatusFailed && *instance.NotebookInstanceStatus != sagemaker.NotebookInstanceStatusStopped { + if err := stopSagemakerNotebookInstance(conn, *instance.NotebookInstanceName); err != nil { + return err + } + } + + deleteOpts := &sagemaker.DeleteNotebookInstanceInput{ + NotebookInstanceName: instance.NotebookInstanceName, + } + + if _, err := conn.DeleteNotebookInstance(deleteOpts); err != nil { + return fmt.Errorf("error trying to delete sagemaker notebook instance (%s): %s", aws.StringValue(instance.NotebookInstanceName), err) + } + + stateConf := &resource.StateChangeConf{ + Pending: []string{ + sagemaker.NotebookInstanceStatusDeleting, + }, + Target: []string{""}, + Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, *instance.NotebookInstanceName), + Timeout: 10 * time.Minute, + } + _, err := stateConf.WaitForState() + if err != nil { + return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to delete: %s", aws.StringValue(instance.NotebookInstanceName), err) + } + + return nil + } +} + +func testAccCheckAWSSagemakerNotebookInstanceName(notebook *sagemaker.DescribeNotebookInstanceOutput, expected string) resource.TestCheckFunc { + return func(s *terraform.State) error { + notebookName := notebook.NotebookInstanceName + if *notebookName != expected { + return fmt.Errorf("Bad Notebook Instance name: %s", *notebook.NotebookInstanceName) + } + + return nil + } +} + +func testAccCheckAWSSagemakerNotebookInstanceTags(notebook *sagemaker.DescribeNotebookInstanceOutput, key string, value string) resource.TestCheckFunc { + return func(s *terraform.State) error { + conn := testAccProvider.Meta().(*AWSClient).sagemakerconn + + ts, err := conn.ListTags(&sagemaker.ListTagsInput{ + ResourceArn: notebook.NotebookInstanceArn, + }) + if err != nil { + return fmt.Errorf("Error listing tags: %s", err) + } + + m := tagsToMapSagemaker(ts.Tags) + v, ok := m[key] + if value != "" && !ok { + return fmt.Errorf("Missing tag: %s", key) + } else if value == "" && ok { + return fmt.Errorf("Extra tag: %s", key) + } + if value == "" { + return nil + } + + if v != value { + return fmt.Errorf("%s: bad value: %s", key, v) + } + + return nil + } +} + +func testAccAWSSagemakerNotebookInstanceConfig(notebookName string) string { + return fmt.Sprintf(` +resource "aws_sagemaker_notebook_instance" "foo" { + name = "%s" + role_arn = "${aws_iam_role.foo.arn}" + instance_type = "ml.t2.medium" +} + +resource "aws_iam_role" "foo" { + name = "%s" + path = "/" + assume_role_policy = "${data.aws_iam_policy_document.assume_role.json}" +} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = [ "sts:AssumeRole" ] + principals { + type = "Service" + identifiers = [ "sagemaker.amazonaws.com" ] + } + } +} +`, notebookName, notebookName) +} + +func testAccAWSSagemakerNotebookInstanceUpdateConfig(notebookName string) string { + return fmt.Sprintf(` +resource "aws_sagemaker_notebook_instance" "foo" { + name = "%s" + role_arn = "${aws_iam_role.foo.arn}" + instance_type = "ml.m4.xlarge" +} + +resource "aws_iam_role" "foo" { + name = "%s" + path = "/" + assume_role_policy = "${data.aws_iam_policy_document.assume_role.json}" +} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = [ "sts:AssumeRole" ] + principals { + type = "Service" + identifiers = [ "sagemaker.amazonaws.com" ] + } + } +} +`, notebookName, notebookName) +} + +func testAccAWSSagemakerNotebookInstanceTagsConfig(notebookName string) string { + return fmt.Sprintf(` +resource "aws_sagemaker_notebook_instance" "foo" { + name = "%s" + role_arn = "${aws_iam_role.foo.arn}" + instance_type = "ml.t2.medium" + tags { + foo = "bar" + } +} + +resource "aws_iam_role" "foo" { + name = "%s" + path = "/" + assume_role_policy = "${data.aws_iam_policy_document.assume_role.json}" +} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = [ "sts:AssumeRole" ] + principals { + type = "Service" + identifiers = [ "sagemaker.amazonaws.com" ] + } + } +} +`, notebookName, notebookName) +} + +func testAccAWSSagemakerNotebookInstanceTagsUpdateConfig(notebookName string) string { + return fmt.Sprintf(` +resource "aws_sagemaker_notebook_instance" "foo" { + name = "%s" + role_arn = "${aws_iam_role.foo.arn}" + instance_type = "ml.t2.medium" + tags { + bar = "baz" + } +} + +resource "aws_iam_role" "foo" { + name = "%s" + path = "/" + assume_role_policy = "${data.aws_iam_policy_document.assume_role.json}" +} + +data "aws_iam_policy_document" "assume_role" { + statement { + actions = [ "sts:AssumeRole" ] + principals { + type = "Service" + identifiers = [ "sagemaker.amazonaws.com" ] + } + } +} +`, notebookName, notebookName) +} diff --git a/aws/tags_sagemaker.go b/aws/tags_sagemaker.go new file mode 100644 index 00000000000..cd0e0ee9eaa --- /dev/null +++ b/aws/tags_sagemaker.go @@ -0,0 +1,120 @@ +package aws + +import ( + "log" + "regexp" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/sagemaker" + "github.com/hashicorp/terraform/helper/resource" + "github.com/hashicorp/terraform/helper/schema" +) + +func tagsFromMapSagemaker(m map[string]interface{}) []*sagemaker.Tag { + result := make([]*sagemaker.Tag, 0, len(m)) + for k, v := range m { + t := &sagemaker.Tag{ + Key: aws.String(k), + Value: aws.String(v.(string)), + } + if !tagIgnoredSagemaker(t) { + result = append(result, t) + } + } + + return result +} + +func tagsToMapSagemaker(ts []*sagemaker.Tag) map[string]string { + result := make(map[string]string) + for _, t := range ts { + if !tagIgnoredSagemaker(t) { + result[*t.Key] = *t.Value + } + } + + return result +} + +func setSagemakerTags(conn *sagemaker.SageMaker, d *schema.ResourceData) error { + if d.HasChange("tags") { + oraw, nraw := d.GetChange("tags") + o := oraw.(map[string]interface{}) + n := nraw.(map[string]interface{}) + create, remove := diffSagemakerTags(tagsFromMapSagemaker(o), tagsFromMapSagemaker(n)) + + if len(remove) > 0 { + err := resource.Retry(5*time.Minute, func() *resource.RetryError { + log.Printf("[DEBUG] Removing tags: %#v from %s", remove, d.Id()) + _, err := conn.DeleteTags(&sagemaker.DeleteTagsInput{ + ResourceArn: aws.String(d.Get("arn").(string)), + TagKeys: remove, + }) + if err != nil { + sagemakerErr, ok := err.(awserr.Error) + if ok && sagemakerErr.Code() == "ResourceNotFound" { + return resource.RetryableError(err) + } + return resource.NonRetryableError(err) + } + return nil + }) + if err != nil { + return err + } + } + if len(create) > 0 { + err := resource.Retry(5*time.Minute, func() *resource.RetryError { + log.Printf("[DEBUG] Creating tags: %s for %s", create, d.Id()) + _, err := conn.AddTags(&sagemaker.AddTagsInput{ + ResourceArn: aws.String(d.Get("arn").(string)), + Tags: create, + }) + if err != nil { + sagemakerErr, ok := err.(awserr.Error) + if ok && sagemakerErr.Code() == "ResourceNotFound" { + return resource.RetryableError(err) + } + return resource.NonRetryableError(err) + } + return nil + }) + if err != nil { + return err + } + } + } + + return nil +} + +func diffSagemakerTags(oldTags, newTags []*sagemaker.Tag) ([]*sagemaker.Tag, []*string) { + create := make(map[string]interface{}) + for _, t := range newTags { + create[*t.Key] = *t.Value + } + + var remove []*string + for _, t := range oldTags { + old, ok := create[*t.Key] + if !ok || old != *t.Value { + remove = append(remove, t.Key) + } + } + + return tagsFromMapSagemaker(create), remove +} + +func tagIgnoredSagemaker(t *sagemaker.Tag) bool { + filter := []string{"^aws:"} + for _, v := range filter { + log.Printf("[DEBUG] Matching %v with %v\n", v, *t.Key) + if r, _ := regexp.MatchString(v, *t.Key); r == true { + log.Printf("[DEBUG] Found AWS specific tag %s (val: %s), ignoring.\n", *t.Key, *t.Value) + return true + } + } + return false +} diff --git a/aws/validators.go b/aws/validators.go index 0347235089c..a53f0182da1 100644 --- a/aws/validators.go +++ b/aws/validators.go @@ -727,6 +727,24 @@ func validateS3BucketLifecycleTransitionStorageClass() schema.SchemaValidateFunc }, false) } +func validateSagemakerName(v interface{}, k string) (ws []string, errors []error) { + value := v.(string) + if !regexp.MustCompile(`^[0-9A-Za-z-]+$`).MatchString(value) { + errors = append(errors, fmt.Errorf( + "only alphanumeric characters and hyphens allowed in %q: %q", + k, value)) + } + if len(value) > 63 { + errors = append(errors, fmt.Errorf( + "%q cannot be longer than 63 characters: %q", k, value)) + } + if regexp.MustCompile(`^-`).MatchString(value) { + errors = append(errors, fmt.Errorf( + "%q cannot begin with a hyphen: %q", k, value)) + } + return +} + func validateDbEventSubscriptionName(v interface{}, k string) (ws []string, errors []error) { value := v.(string) if !regexp.MustCompile(`^[0-9A-Za-z-]+$`).MatchString(value) { diff --git a/aws/validators_test.go b/aws/validators_test.go index 53bed5bc920..02ec45c78df 100644 --- a/aws/validators_test.go +++ b/aws/validators_test.go @@ -668,6 +668,35 @@ func TestValidateS3BucketLifecycleTimestamp(t *testing.T) { } } +func TestValidateSagemakerName(t *testing.T) { + validNames := []string{ + "ValidSageMakerName", + "Valid-5a63Mak3r-Name", + "123-456-789", + "1234", + strings.Repeat("W", 63), + } + for _, v := range validNames { + _, errors := validateSagemakerName(v, "name") + if len(errors) != 0 { + t.Fatalf("%q should be a valid SageMaker name with maximum length 63 chars: %q", v, errors) + } + } + + invalidNames := []string{ + "Invalid name", // blanks are not allowed + "1#{}nook", // other non-alphanumeric chars + "-nook", // cannot start with hyphen + strings.Repeat("W", 64), // length > 63 + } + for _, v := range invalidNames { + _, errors := validateSagemakerName(v, "name") + if len(errors) == 0 { + t.Fatalf("%q should be an invalid SageMaker name", v) + } + } +} + func TestValidateIntegerInSlice(t *testing.T) { cases := []struct { val interface{} diff --git a/website/aws.erb b/website/aws.erb index 9d3894c07ba..960b28ce3bb 100644 --- a/website/aws.erb +++ b/website/aws.erb @@ -2220,6 +2220,17 @@ +