Skip to content

Commit

Permalink
Merge pull request #35873 from deepakbshetty/f-aws_sagemaker_model-mo…
Browse files Browse the repository at this point in the history
…del-access_config

Add model_access_config, multi_model_config and inference_specification_name to sagemaker_model primary_container and container block
  • Loading branch information
ewbankkit authored Sep 12, 2024
2 parents 1354c93 + 68baca3 commit 50a8f2c
Show file tree
Hide file tree
Showing 4 changed files with 424 additions and 3 deletions.
7 changes: 7 additions & 0 deletions .changelog/35873.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
```release-note:enhancement
resource/aws_sagemaker_model: Add `primary_container.model_data_source.s3_data_source.model_access_config`, `primary_container.multi_model_config`, `container.model_data_source.s3_data_source.model_access_config`, and ``container.multi_model_config` configuration blocks
```

```release-note:enhancement
resource/aws_sagemaker_model: Add `primary_container.inference_specification_name` and `container.inference_specification_name` arguments
```
155 changes: 155 additions & 0 deletions internal/service/sagemaker/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,48 @@ func resourceModel() *schema.Resource {
ForceNew: true,
ValidateDiagFunc: enum.Validate[awstypes.ModelCompressionType](),
},
"model_access_config": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"accept_eula": {
Type: schema.TypeBool,
Required: true,
ForceNew: true,
},
},
},
},
},
},
},
},
},
},
"inference_specification_name": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validName,
},
"multi_model_config": {
Type: schema.TypeList,
Optional: true,
ForceNew: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"model_cache_setting": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateDiagFunc: enum.Validate[awstypes.ModelCacheSetting](),
},
},
},
},
},
},
},
Expand Down Expand Up @@ -294,12 +330,49 @@ func resourceModel() *schema.Resource {
ForceNew: true,
ValidateDiagFunc: enum.Validate[awstypes.ModelCompressionType](),
},
"model_access_config": {
Type: schema.TypeList,
Optional: true,
ForceNew: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"accept_eula": {
Type: schema.TypeBool,
Required: true,
ForceNew: true,
},
},
},
},
},
},
},
},
},
},
"inference_specification_name": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validName,
},
"multi_model_config": {
Type: schema.TypeList,
Optional: true,
ForceNew: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"model_cache_setting": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateDiagFunc: enum.Validate[awstypes.ModelCacheSetting](),
},
},
},
},
},
},
},
Expand Down Expand Up @@ -551,6 +624,14 @@ func expandContainer(m map[string]interface{}) *awstypes.ContainerDefinition {
container.ImageConfig = expandModelImageConfig(v.([]interface{}))
}

if v, ok := m["inference_specification_name"]; ok && v.(string) != "" {
container.InferenceSpecificationName = aws.String(v.(string))
}

if v, ok := m["multi_model_config"].([]interface{}); ok && len(v) > 0 {
container.MultiModelConfig = expandMultiModelConfig(v)
}

return &container
}

Expand Down Expand Up @@ -589,6 +670,10 @@ func expandS3ModelDataSource(l []interface{}) *awstypes.S3ModelDataSource {
s3ModelDataSource.CompressionType = awstypes.ModelCompressionType(v.(string))
}

if v, ok := m["model_access_config"].([]interface{}); ok && len(v) > 0 {
s3ModelDataSource.ModelAccessConfig = expandModelAccessConfig(v)
}

return &s3ModelDataSource
}

Expand Down Expand Up @@ -634,6 +719,38 @@ func expandContainers(a []interface{}) []awstypes.ContainerDefinition {
return containers
}

func expandModelAccessConfig(l []interface{}) *awstypes.ModelAccessConfig {
if len(l) == 0 {
return nil
}

m := l[0].(map[string]interface{})

modelAccessConfig := &awstypes.ModelAccessConfig{}

if v, ok := m["accept_eula"].(bool); ok {
modelAccessConfig.AcceptEula = aws.Bool(v)
}

return modelAccessConfig
}

func expandMultiModelConfig(l []interface{}) *awstypes.MultiModelConfig {
if len(l) == 0 {
return nil
}

m := l[0].(map[string]interface{})

multiModelConfig := &awstypes.MultiModelConfig{}

if v, ok := m["model_cache_setting"].(string); ok && v != "" {
multiModelConfig.ModelCacheSetting = awstypes.ModelCacheSetting(v)
}

return multiModelConfig
}

func flattenContainer(container *awstypes.ContainerDefinition) []interface{} {
if container == nil {
return []interface{}{}
Expand Down Expand Up @@ -667,6 +784,14 @@ func flattenContainer(container *awstypes.ContainerDefinition) []interface{} {
cfg["image_config"] = flattenImageConfig(container.ImageConfig)
}

if container.InferenceSpecificationName != nil {
cfg["inference_specification_name"] = aws.ToString(container.InferenceSpecificationName)
}

if container.MultiModelConfig != nil {
cfg["multi_model_config"] = flattenMultiModelConfig(container.MultiModelConfig)
}

return []interface{}{cfg}
}

Expand Down Expand Up @@ -699,6 +824,10 @@ func flattenS3ModelDataSource(s3ModelDataSource *awstypes.S3ModelDataSource) []i

cfg["compression_type"] = s3ModelDataSource.CompressionType

if s3ModelDataSource.ModelAccessConfig != nil {
cfg["model_access_config"] = flattenModelAccessConfig(s3ModelDataSource.ModelAccessConfig)
}

return []interface{}{cfg}
}

Expand Down Expand Up @@ -740,6 +869,32 @@ func flattenContainers(containers []awstypes.ContainerDefinition) []interface{}
return fContainers
}

func flattenModelAccessConfig(config *awstypes.ModelAccessConfig) []interface{} {
if config == nil {
return []interface{}{}
}

cfg := make(map[string]interface{})

cfg["accept_eula"] = aws.ToBool(config.AcceptEula)

return []interface{}{cfg}
}

func flattenMultiModelConfig(config *awstypes.MultiModelConfig) []interface{} {
if config == nil {
return []interface{}{}
}

cfg := make(map[string]interface{})

if config.ModelCacheSetting != "" {
cfg["model_cache_setting"] = config.ModelCacheSetting
}

return []interface{}{cfg}
}

func expandModelInferenceExecutionConfig(l []interface{}) *awstypes.InferenceExecutionConfig {
if len(l) == 0 {
return nil
Expand Down
Loading

0 comments on commit 50a8f2c

Please sign in to comment.