Skip to content

Commit

Permalink
feat(GROW-2908): enable adding generic hcl block to root terraform block
Browse files Browse the repository at this point in the history
  • Loading branch information
ipcrm committed May 15, 2024
1 parent 8c76d48 commit 749e692
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 9 deletions.
16 changes: 13 additions & 3 deletions lwgenerate/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,9 @@ type GenerateAwsTfConfigurationArgs struct {

// Default AWS Provider Tags
ProviderDefaultTags map[string]interface{}

// Add custom blocks to the root `terraform{}` block. Can be used for advanced configuration. Things like backend, etc
ExtraBlocksRootTerraform []*hclwrite.Block
}

func (args *GenerateAwsTfConfigurationArgs) IsEmpty() bool {
Expand Down Expand Up @@ -732,6 +735,12 @@ func WithS3BucketNotification(s3BucketNotifiaction bool) AwsTerraformModifier {
}
}

func WithExtraRootBlocks(blocks []*hclwrite.Block) AwsTerraformModifier {
return func(c *GenerateAwsTfConfigurationArgs) {
c.ExtraBlocksRootTerraform = blocks
}
}

// Generate new Terraform code based on the supplied args.
func (args *GenerateAwsTfConfigurationArgs) Generate() (string, error) {
// Validate inputs
Expand All @@ -740,7 +749,7 @@ func (args *GenerateAwsTfConfigurationArgs) Generate() (string, error) {
}

// Create blocks
requiredProviders, err := createRequiredProviders()
requiredProviders, err := createRequiredProviders(args.ExtraBlocksRootTerraform)
if err != nil {
return "", errors.Wrap(err, "failed to generate required providers")
}
Expand Down Expand Up @@ -793,8 +802,9 @@ func (args *GenerateAwsTfConfigurationArgs) Generate() (string, error) {
return hclBlocks, nil
}

func createRequiredProviders() (*hclwrite.Block, error) {
return lwgenerate.CreateRequiredProviders(
func createRequiredProviders(extraBlocks []*hclwrite.Block) (*hclwrite.Block, error) {
return lwgenerate.CreateRequiredProvidersWithCustomBlocks(
extraBlocks,
lwgenerate.NewRequiredProvider("lacework",
lwgenerate.HclRequiredProviderWithSource(lwgenerate.LaceworkProviderSource),
lwgenerate.HclRequiredProviderWithVersion(lwgenerate.LaceworkProviderVersion)))
Expand Down
23 changes: 23 additions & 0 deletions lwgenerate/aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"testing"

"github.com/hashicorp/hcl/v2/hclwrite"
"github.com/lacework/go-sdk/lwgenerate"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -91,6 +92,16 @@ func TestGenerationConfig(t *testing.T) {
assert.Equal(t, reqProviderAndRegion(moduleImportConfig), hcl)
}

func TestGenerationConfigWithCustomBackendBlock(t *testing.T) {
customBlock, err := lwgenerate.HclCreateGenericBlock("backend", []string{"s3"}, nil)
assert.NoError(t, err)
hcl, err := NewTerraform(false, false, true, false, WithAwsRegion("us-east-2"),
WithExtraRootBlocks([]*hclwrite.Block{customBlock})).Generate()
assert.Nil(t, err)
assert.NotNil(t, hcl)
assert.Equal(t, requiredProvidersWithCustomBlock+"\n"+awsProvider+"\n"+moduleImportConfig, hcl)
}

func TestGenerationConfigWithOutputs(t *testing.T) {
hcl, err := NewTerraform(
false, false, true, false, WithAwsRegion("us-east-2"),
Expand Down Expand Up @@ -390,6 +401,18 @@ func TestGenerationCloudTrailS3BucketNotification(t *testing.T) {
)
}

var requiredProvidersWithCustomBlock = `terraform {
required_providers {
lacework = {
source = "lacework/lacework"
version = "~> 1.0"
}
}
backend "s3" {
}
}
`

var requiredProviders = `terraform {
required_providers {
lacework = {
Expand Down
50 changes: 44 additions & 6 deletions lwgenerate/hcl.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,13 @@ func CreateHclStringOutput(blocks []*hclwrite.Block) string {
return string(file.Bytes())
}

// CreateRequiredProviders Create required providers block
func CreateRequiredProviders(providers ...*HclRequiredProvider) (*hclwrite.Block, error) {
block, err := HclCreateGenericBlock("terraform", nil, nil)
if err != nil {
return nil, err
}
// rootTerraformBlock is a helper that creates the literal `terraform{}` hcl block
func rootTerraformBlock() (*hclwrite.Block, error) {
return HclCreateGenericBlock("terraform", nil, nil)
}

// createRequiredProviders is a helper that creates the `required_providers` hcl block
func createRequiredProviders(providers ...*HclRequiredProvider) (*hclwrite.Block, error) {
providerDetails := map[string]interface{}{}
for _, provider := range providers {
details := map[string]interface{}{}
Expand All @@ -579,7 +579,45 @@ func CreateRequiredProviders(providers ...*HclRequiredProvider) (*hclwrite.Block
if err != nil {
return nil, err
}

return requiredProviders, nil
}

// CreateRequiredProviders Create required providers block
func CreateRequiredProviders(providers ...*HclRequiredProvider) (*hclwrite.Block, error) {
block, err := rootTerraformBlock()
if err != nil {
return nil, err
}

requiredProviders, err := createRequiredProviders(providers...)
if err != nil {
return nil, err
}

block.Body().AppendBlock(requiredProviders)
return block, nil
}

// CreateRequiredProviders Create required providers block
func CreateRequiredProvidersWithCustomBlocks(
blocks []*hclwrite.Block,
providers ...*HclRequiredProvider,
) (*hclwrite.Block, error) {
block, err := rootTerraformBlock()
if err != nil {
return nil, err
}

requiredProviders, err := createRequiredProviders(providers...)
if err != nil {
return nil, err
}

block.Body().AppendBlock(requiredProviders)
for _, customBlock := range blocks {
block.Body().AppendBlock(customBlock)
}

return block, nil
}
Expand Down
34 changes: 34 additions & 0 deletions lwgenerate/hcl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,22 @@ func TestRequiredProvidersBlock(t *testing.T) {
assert.Equal(t, testRequiredProvider, lwgenerate.CreateHclStringOutput([]*hclwrite.Block{data}))
}

func TestRequiredProvidersBlockWithCustomBlocks(t *testing.T) {
provider1 := lwgenerate.NewRequiredProvider("foo",
lwgenerate.HclRequiredProviderWithSource("test/test"))
provider2 := lwgenerate.NewRequiredProvider("bar",
lwgenerate.HclRequiredProviderWithVersion("~> 0.1"))
provider3 := lwgenerate.NewRequiredProvider("lacework",
lwgenerate.HclRequiredProviderWithSource("lacework/lacework"),
lwgenerate.HclRequiredProviderWithVersion("~> 0.1"))

customBlock, err := lwgenerate.HclCreateGenericBlock("backend", []string{"s3"}, nil)
assert.NoError(t, err)
data, err := lwgenerate.CreateRequiredProvidersWithCustomBlocks([]*hclwrite.Block{customBlock}, provider1, provider2, provider3)
assert.Nil(t, err)
assert.Equal(t, testRequiredProviderWithCustomBlocks, lwgenerate.CreateHclStringOutput([]*hclwrite.Block{data}))
}

func TestModuleBlockWithComplexAttributes(t *testing.T) {
data, err := lwgenerate.NewModule("foo",
"mycorp/mycloud",
Expand Down Expand Up @@ -192,6 +208,24 @@ func TestOutputBlockCreation(t *testing.T) {
})
}

var testRequiredProviderWithCustomBlocks = `terraform {
required_providers {
bar = {
version = "~> 0.1"
}
foo = {
source = "test/test"
}
lacework = {
source = "lacework/lacework"
version = "~> 0.1"
}
}
backend "s3" {
}
}
`

var testRequiredProvider = `terraform {
required_providers {
bar = {
Expand Down

0 comments on commit 749e692

Please sign in to comment.