Skip to content

Commit

Permalink
Merge pull request #15385 from DrFaust92/r/sagemaker_instance_refactor
Browse files Browse the repository at this point in the history
r/sagemaker_notebook_instance - `lifecycle_config_name`, `root_access`, and `default_code_repository` allow updating + refactor tests
  • Loading branch information
breathingdust authored Oct 22, 2020
2 parents 6f1a4ee + e1ec08f commit 181ffb0
Show file tree
Hide file tree
Showing 4 changed files with 444 additions and 483 deletions.
37 changes: 37 additions & 0 deletions aws/internal/service/sagemaker/waiter/status.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package waiter

import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sagemaker"
"github.com/hashicorp/aws-sdk-go-base/tfawserr"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
)

const (
SagemakerNotebookInstanceStatusNotFound = "NotFound"
)

// NotebookInstanceStatus fetches the NotebookInstance and its Status
func NotebookInstanceStatus(conn *sagemaker.SageMaker, notebookName string) resource.StateRefreshFunc {
return func() (interface{}, string, error) {
input := &sagemaker.DescribeNotebookInstanceInput{
NotebookInstanceName: aws.String(notebookName),
}

output, err := conn.DescribeNotebookInstance(input)

if tfawserr.ErrMessageContains(err, "ValidationException", "RecordNotFound") {
return nil, SagemakerNotebookInstanceStatusNotFound, nil
}

if err != nil {
return nil, sagemaker.NotebookInstanceStatusFailed, err
}

if output == nil {
return nil, SagemakerNotebookInstanceStatusNotFound, nil
}

return output, aws.StringValue(output.NotebookInstanceStatus), nil
}
}
78 changes: 78 additions & 0 deletions aws/internal/service/sagemaker/waiter/waiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package waiter

import (
"time"

"github.com/aws/aws-sdk-go/service/sagemaker"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
)

const (
NotebookInstanceInServiceTimeout = 10 * time.Minute
NotebookInstanceStoppedTimeout = 10 * time.Minute
NotebookInstanceDeletedTimeout = 10 * time.Minute
)

// NotebookInstanceInService waits for a NotebookInstance to return InService
func NotebookInstanceInService(conn *sagemaker.SageMaker, notebookName string) (*sagemaker.DescribeNotebookInstanceOutput, error) {
stateConf := &resource.StateChangeConf{
Pending: []string{
SagemakerNotebookInstanceStatusNotFound,
sagemaker.NotebookInstanceStatusUpdating,
sagemaker.NotebookInstanceStatusPending,
sagemaker.NotebookInstanceStatusStopped,
},
Target: []string{sagemaker.NotebookInstanceStatusInService},
Refresh: NotebookInstanceStatus(conn, notebookName),
Timeout: NotebookInstanceInServiceTimeout,
}

outputRaw, err := stateConf.WaitForState()

if output, ok := outputRaw.(*sagemaker.DescribeNotebookInstanceOutput); ok {
return output, err
}

return nil, err
}

// NotebookInstanceStopped waits for a NotebookInstance to return Stopped
func NotebookInstanceStopped(conn *sagemaker.SageMaker, notebookName string) (*sagemaker.DescribeNotebookInstanceOutput, error) {
stateConf := &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusUpdating,
sagemaker.NotebookInstanceStatusStopping,
},
Target: []string{sagemaker.NotebookInstanceStatusStopped},
Refresh: NotebookInstanceStatus(conn, notebookName),
Timeout: NotebookInstanceStoppedTimeout,
}

outputRaw, err := stateConf.WaitForState()

if output, ok := outputRaw.(*sagemaker.DescribeNotebookInstanceOutput); ok {
return output, err
}

return nil, err
}

// NotebookInstanceDeleted waits for a NotebookInstance to return Deleted
func NotebookInstanceDeleted(conn *sagemaker.SageMaker, notebookName string) (*sagemaker.DescribeNotebookInstanceOutput, error) {
stateConf := &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusDeleting,
},
Target: []string{},
Refresh: NotebookInstanceStatus(conn, notebookName),
Timeout: NotebookInstanceDeletedTimeout,
}

outputRaw, err := stateConf.WaitForState()

if output, ok := outputRaw.(*sagemaker.DescribeNotebookInstanceOutput); ok {
return output, err
}

return nil, err
}
144 changes: 63 additions & 81 deletions aws/resource_aws_sagemaker_notebook_instance.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package aws

import (
"context"
"fmt"
"log"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sagemaker"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/terraform-providers/terraform-provider-aws/aws/internal/keyvaluetags"
"github.com/terraform-providers/terraform-provider-aws/aws/internal/service/sagemaker/waiter"
)

func resourceAwsSagemakerNotebookInstance() *schema.Resource {
Expand All @@ -22,6 +25,11 @@ func resourceAwsSagemakerNotebookInstance() *schema.Resource {
Importer: &schema.ResourceImporter{
State: schema.ImportStatePassthrough,
},
CustomizeDiff: customdiff.Sequence(
customdiff.ForceNewIfChange("volume_size", func(_ context.Context, old, new, meta interface{}) bool {
return new.(int) < old.(int)
}),
),

Schema: map[string]*schema.Schema{
"arn": {
Expand All @@ -37,13 +45,15 @@ func resourceAwsSagemakerNotebookInstance() *schema.Resource {
},

"role_arn": {
Type: schema.TypeString,
Required: true,
Type: schema.TypeString,
Required: true,
ValidateFunc: validateArn,
},

"instance_type": {
Type: schema.TypeString,
Required: true,
Type: schema.TypeString,
Required: true,
ValidateFunc: validation.StringInSlice(sagemaker.InstanceType_Values(), false),
},

"volume_size": {
Expand Down Expand Up @@ -77,33 +87,26 @@ func resourceAwsSagemakerNotebookInstance() *schema.Resource {
"lifecycle_config_name": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
},

"root_access": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Default: sagemaker.RootAccessEnabled,
ValidateFunc: validation.StringInSlice(
sagemaker.RootAccess_Values(), false),
Type: schema.TypeString,
Optional: true,
Default: sagemaker.RootAccessEnabled,
ValidateFunc: validation.StringInSlice(sagemaker.RootAccess_Values(), false),
},

"direct_internet_access": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Default: sagemaker.DirectInternetAccessEnabled,
ValidateFunc: validation.StringInSlice([]string{
sagemaker.DirectInternetAccessDisabled,
sagemaker.DirectInternetAccessEnabled,
}, false),
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Default: sagemaker.DirectInternetAccessEnabled,
ValidateFunc: validation.StringInSlice(sagemaker.DirectInternetAccess_Values(), false),
},

"default_code_repository": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
},

"tags": tagsSchema(),
Expand Down Expand Up @@ -164,19 +167,8 @@ func resourceAwsSagemakerNotebookInstanceCreate(d *schema.ResourceData, meta int
d.SetId(name)
log.Printf("[INFO] sagemaker notebook instance ID: %s", d.Id())

stateConf := &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusUpdating,
sagemaker.NotebookInstanceStatusPending,
sagemaker.NotebookInstanceStatusStopped,
},
Target: []string{sagemaker.NotebookInstanceStatusInService},
Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id()),
Timeout: 10 * time.Minute,
}
_, err = stateConf.WaitForState()
if err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to create: %s", d.Id(), err)
if _, err := waiter.NotebookInstanceInService(conn, d.Id()); err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to create: %w", d.Id(), err)
}

return resourceAwsSagemakerNotebookInstanceRead(d, meta)
Expand Down Expand Up @@ -289,6 +281,29 @@ func resourceAwsSagemakerNotebookInstanceUpdate(d *schema.ResourceData, meta int
hasChanged = true
}

if d.HasChange("lifecycle_config_name") {
if v, ok := d.GetOk("lifecycle_config_name"); ok {
updateOpts.LifecycleConfigName = aws.String(v.(string))
} else {
updateOpts.DisassociateLifecycleConfig = aws.Bool(true)
}
hasChanged = true
}

if d.HasChange("default_code_repository") {
if v, ok := d.GetOk("default_code_repository"); ok {
updateOpts.DefaultCodeRepository = aws.String(v.(string))
} else {
updateOpts.DisassociateDefaultCodeRepository = aws.Bool(true)
}
hasChanged = true
}

if d.HasChange("root_access") {
updateOpts.RootAccess = aws.String(d.Get("root_access").(string))
hasChanged = true
}

if hasChanged {

// Stop notebook
Expand All @@ -303,17 +318,8 @@ func resourceAwsSagemakerNotebookInstanceUpdate(d *schema.ResourceData, meta int
return fmt.Errorf("error updating sagemaker notebook instance: %s", err)
}

stateConf := &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusUpdating,
},
Target: []string{sagemaker.NotebookInstanceStatusStopped},
Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id()),
Timeout: 10 * time.Minute,
}
_, err := stateConf.WaitForState()
if err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to update: %s", d.Id(), err)
if _, err := waiter.NotebookInstanceStopped(conn, d.Id()); err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to stop: %w", d.Id(), err)
}

// Restart if needed
Expand Down Expand Up @@ -356,19 +362,8 @@ func resourceAwsSagemakerNotebookInstanceUpdate(d *schema.ResourceData, meta int
return fmt.Errorf("Error waiting for sagemaker notebook instance to start: %s", err)
}

stateConf = &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusUpdating,
sagemaker.NotebookInstanceStatusPending,
sagemaker.NotebookInstanceStatusStopped,
},
Target: []string{sagemaker.NotebookInstanceStatusInService},
Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id()),
Timeout: 10 * time.Minute,
}
_, err = stateConf.WaitForState()
if err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to start after update: %s", d.Id(), err)
if _, err := waiter.NotebookInstanceInService(conn, d.Id()); err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to to start after update: %w", d.Id(), err)
}
}
}
Expand All @@ -389,7 +384,9 @@ func resourceAwsSagemakerNotebookInstanceDelete(d *schema.ResourceData, meta int
}
return fmt.Errorf("unable to find sagemaker notebook instance to delete (%s): %s", d.Id(), err)
}
if *notebook.NotebookInstanceStatus != sagemaker.NotebookInstanceStatusFailed && *notebook.NotebookInstanceStatus != sagemaker.NotebookInstanceStatusStopped {

if aws.StringValue(notebook.NotebookInstanceStatus) != sagemaker.NotebookInstanceStatusFailed &&
aws.StringValue(notebook.NotebookInstanceStatus) != sagemaker.NotebookInstanceStatusStopped {
if err := stopSagemakerNotebookInstance(conn, d.Id()); err != nil {
return err
}
Expand All @@ -403,17 +400,11 @@ func resourceAwsSagemakerNotebookInstanceDelete(d *schema.ResourceData, meta int
return fmt.Errorf("error trying to delete sagemaker notebook instance (%s): %s", d.Id(), err)
}

stateConf := &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusDeleting,
},
Target: []string{""},
Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, d.Id()),
Timeout: 10 * time.Minute,
}
_, err = stateConf.WaitForState()
if err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to delete: %s", d.Id(), err)
if _, err := waiter.NotebookInstanceDeleted(conn, d.Id()); err != nil {
if isAWSErr(err, "ValidationException", "RecordNotFound") {
return nil
}
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to delete: %w", d.Id(), err)
}

return nil
Expand All @@ -430,7 +421,7 @@ func stopSagemakerNotebookInstance(conn *sagemaker.SageMaker, id string) error {
}
return fmt.Errorf("unable to find sagemaker notebook instance (%s): %s", id, err)
}
if *notebook.NotebookInstanceStatus == sagemaker.NotebookInstanceStatusStopped {
if aws.StringValue(notebook.NotebookInstanceStatus) == sagemaker.NotebookInstanceStatusStopped {
return nil
}

Expand All @@ -442,17 +433,8 @@ func stopSagemakerNotebookInstance(conn *sagemaker.SageMaker, id string) error {
return fmt.Errorf("Error stopping sagemaker notebook instance: %s", err)
}

stateConf := &resource.StateChangeConf{
Pending: []string{
sagemaker.NotebookInstanceStatusStopping,
},
Target: []string{sagemaker.NotebookInstanceStatusStopped},
Refresh: sagemakerNotebookInstanceStateRefreshFunc(conn, id),
Timeout: 10 * time.Minute,
}
_, err = stateConf.WaitForState()
if err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to stop: %s", id, err)
if _, err := waiter.NotebookInstanceStopped(conn, id); err != nil {
return fmt.Errorf("error waiting for sagemaker notebook instance (%s) to stop: %w", id, err)
}

return nil
Expand Down
Loading

0 comments on commit 181ffb0

Please sign in to comment.