diff --git a/pkg/iac/scanners/terraform/parser/evaluator.go b/pkg/iac/scanners/terraform/parser/evaluator.go index 1dea2e9473a3..a203451245ff 100644 --- a/pkg/iac/scanners/terraform/parser/evaluator.go +++ b/pkg/iac/scanners/terraform/parser/evaluator.go @@ -3,7 +3,6 @@ package parser import ( "context" "errors" - "fmt" "io/fs" "reflect" "time" @@ -148,7 +147,8 @@ func (e *evaluator) EvaluateAll(ctx context.Context) (terraform.Modules, map[str } } - // expand out resources and modules via count (not a typo, we do this twice so every order is processed) + // expand out resources and modules via count, for-each and dynamic + // (not a typo, we do this twice so every order is processed) e.blocks = e.expandBlocks(e.blocks) e.blocks = e.expandBlocks(e.blocks) @@ -204,7 +204,7 @@ func (e *evaluator) isModuleLocal() bool { } func (e *evaluator) expandBlocks(blocks terraform.Blocks) terraform.Blocks { - return e.expandDynamicBlocks(e.expandBlockForEaches(e.expandBlockCounts(blocks))...) + return e.expandDynamicBlocks(e.expandBlockForEaches(e.expandBlockCounts(blocks), false)...) } func (e *evaluator) expandDynamicBlocks(blocks ...*terraform.Block) terraform.Blocks { @@ -219,80 +219,49 @@ func (e *evaluator) expandDynamicBlock(b *terraform.Block) { e.expandDynamicBlock(sub) } for _, sub := range b.AllBlocks().OfType("dynamic") { + if sub.IsExpanded() { + continue + } blockName := sub.TypeLabel() - expanded := e.expandBlockForEaches(terraform.Blocks{sub}) + expanded := e.expandBlockForEaches(terraform.Blocks{sub}, true) for _, ex := range expanded { if content := ex.GetBlock("content"); content.IsNotNil() { _ = e.expandDynamicBlocks(content) b.InjectBlock(content, blockName) } } + sub.MarkExpanded() } } -func validateForEachArg(arg cty.Value) error { - if arg.IsNull() { - return errors.New("arg is null") - } - - ty := arg.Type() - - if !arg.IsKnown() || ty.Equals(cty.DynamicPseudoType) || arg.LengthInt() == 0 { - return nil - } - - if !(ty.IsSetType() || ty.IsObjectType() || ty.IsMapType()) { - return fmt.Errorf("%s type is not supported: arg is not set or map", ty.FriendlyName()) - } - - if ty.IsSetType() { - if !ty.ElementType().Equals(cty.String) { - return errors.New("arg is not set of strings") - } - - it := arg.ElementIterator() - for it.Next() { - key, _ := it.Element() - if key.IsNull() { - return errors.New("arg is set of strings, but contains null") - } - - if !key.IsKnown() { - return errors.New("arg is set of strings, but contains unknown value") - } - } - } - - return nil -} - func isBlockSupportsForEachMetaArgument(block *terraform.Block) bool { return slices.Contains([]string{"module", "resource", "data", "dynamic"}, block.Type()) } -func (e *evaluator) expandBlockForEaches(blocks terraform.Blocks) terraform.Blocks { +func (e *evaluator) expandBlockForEaches(blocks terraform.Blocks, isDynamic bool) terraform.Blocks { var forEachFiltered terraform.Blocks for _, block := range blocks { forEachAttr := block.GetAttribute("for_each") - if forEachAttr.IsNil() || block.IsCountExpanded() || !isBlockSupportsForEachMetaArgument(block) { + if forEachAttr.IsNil() || block.IsExpanded() || !isBlockSupportsForEachMetaArgument(block) { forEachFiltered = append(forEachFiltered, block) continue } forEachVal := forEachAttr.Value() - if err := validateForEachArg(forEachVal); err != nil { - e.debug.Log(`"for_each" argument is invalid: %s`, err.Error()) + if forEachVal.IsNull() || !forEachVal.IsKnown() || !forEachAttr.IsIterable() { continue } clones := make(map[string]cty.Value) _ = forEachAttr.Each(func(key cty.Value, val cty.Value) { - if !key.Type().Equals(cty.String) { + // instances are identified by a map key (or set member) from the value provided to for_each + idx, err := convert.Convert(key, cty.String) + if err != nil { e.debug.Log( `Invalid "for-each" argument: map key (or set value) is not a string, but %s`, key.Type().FriendlyName(), @@ -300,22 +269,34 @@ func (e *evaluator) expandBlockForEaches(blocks terraform.Blocks) terraform.Bloc return } - clone := block.Clone(key) + // if the argument is a collection but not a map, then the resource identifier + // is the value of the collection. The exception is the use of for-each inside a dynamic block, + // because in this case the collection element may not be a primitive value. + if (forEachVal.Type().IsCollectionType() || forEachVal.Type().IsTupleType()) && + !forEachVal.Type().IsMapType() && !isDynamic { + stringVal, err := convert.Convert(val, cty.String) + if err != nil { + e.debug.Log("Failed to convert for-each arg %v to string", val) + return + } + idx = stringVal + } + + clone := block.Clone(idx) ctx := clone.Context() e.copyVariables(block, clone) - ctx.SetByDot(key, "each.key") + ctx.SetByDot(idx, "each.key") ctx.SetByDot(val, "each.value") - - ctx.Set(key, block.TypeLabel(), "key") + ctx.Set(idx, block.TypeLabel(), "key") ctx.Set(val, block.TypeLabel(), "value") forEachFiltered = append(forEachFiltered, clone) values := clone.Values() - clones[key.AsString()] = values + clones[idx.AsString()] = values e.ctx.SetByDot(values, clone.GetMetadata().Reference()) }) @@ -341,7 +322,7 @@ func (e *evaluator) expandBlockCounts(blocks terraform.Blocks) terraform.Blocks var countFiltered terraform.Blocks for _, block := range blocks { countAttr := block.GetAttribute("count") - if countAttr.IsNil() || block.IsCountExpanded() || !isBlockSupportsCountMetaArgument(block) { + if countAttr.IsNil() || block.IsExpanded() || !isBlockSupportsCountMetaArgument(block) { countFiltered = append(countFiltered, block) continue } diff --git a/pkg/iac/scanners/terraform/parser/evaluator_test.go b/pkg/iac/scanners/terraform/parser/evaluator_test.go deleted file mode 100644 index 8d3ef7b0f6e0..000000000000 --- a/pkg/iac/scanners/terraform/parser/evaluator_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package parser - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/zclconf/go-cty/cty" -) - -func TestValidateForEachArg(t *testing.T) { - tests := []struct { - name string - arg cty.Value - expectedError string - }{ - { - name: "empty set", - arg: cty.SetValEmpty(cty.String), - }, - { - name: "set of strings", - arg: cty.SetVal([]cty.Value{cty.StringVal("val1"), cty.StringVal("val2")}), - }, - { - name: "set of non-strings", - arg: cty.SetVal([]cty.Value{cty.NumberIntVal(1), cty.NumberIntVal(2)}), - expectedError: "is not set of strings", - }, - { - name: "set with null", - arg: cty.SetVal([]cty.Value{cty.StringVal("val1"), cty.NullVal(cty.String)}), - expectedError: "arg is set of strings, but contains null", - }, - { - name: "set with unknown", - arg: cty.SetVal([]cty.Value{cty.StringVal("val1"), cty.UnknownVal(cty.String)}), - expectedError: "arg is set of strings, but contains unknown", - }, - { - name: "set with unknown", - arg: cty.SetVal([]cty.Value{cty.StringVal("val1"), cty.UnknownVal(cty.String)}), - expectedError: "arg is set of strings, but contains unknown", - }, - { - name: "non empty map", - arg: cty.MapVal(map[string]cty.Value{ - "val1": cty.StringVal("..."), - "val2": cty.StringVal("..."), - }), - }, - { - name: "map with unknown", - arg: cty.MapVal(map[string]cty.Value{ - "val1": cty.UnknownVal(cty.String), - "val2": cty.StringVal("..."), - }), - }, - { - name: "empty obj", - arg: cty.EmptyObjectVal, - }, - { - name: "obj with strings", - arg: cty.ObjectVal(map[string]cty.Value{ - "val1": cty.StringVal("..."), - "val2": cty.StringVal("..."), - }), - }, - { - name: "null", - arg: cty.NullVal(cty.Set(cty.String)), - expectedError: "arg is null", - }, - { - name: "unknown", - arg: cty.UnknownVal(cty.Set(cty.String)), - }, - { - name: "dynamic", - arg: cty.DynamicVal, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateForEachArg(tt.arg) - if tt.expectedError != "" && err != nil { - assert.ErrorContains(t, err, tt.expectedError) - return - } - assert.NoError(t, err) - }) - } -} diff --git a/pkg/iac/scanners/terraform/parser/parser_test.go b/pkg/iac/scanners/terraform/parser/parser_test.go index 926d9e56603d..c8c6d727c112 100644 --- a/pkg/iac/scanners/terraform/parser/parser_test.go +++ b/pkg/iac/scanners/terraform/parser/parser_test.go @@ -8,6 +8,7 @@ import ( "github.com/aquasecurity/trivy/internal/testutil" "github.com/aquasecurity/trivy/pkg/iac/scanners/options" + "github.com/aquasecurity/trivy/pkg/iac/terraform" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/zclconf/go-cty/cty" @@ -904,6 +905,91 @@ data "http" "example" { } func TestForEach(t *testing.T) { + tests := []struct { + name string + src string + expectedBucketName string + expectedNameLabel string + }{ + { + name: "arg is set and ref to each.key", + src: `locals { + buckets = ["bucket1"] +} + +resource "aws_s3_bucket" "this" { + for_each = toset(local.buckets) + bucket = each.key +}`, + expectedBucketName: "bucket1", + expectedNameLabel: `this["bucket1"]`, + }, + { + name: "arg is set and ref to each.value", + src: `locals { + buckets = ["bucket1"] +} + +resource "aws_s3_bucket" "this" { + for_each = toset(local.buckets) + bucket = each.value +}`, + expectedBucketName: "bucket1", + expectedNameLabel: `this["bucket1"]`, + }, + { + name: "arg is map and ref to each.key", + src: `locals { + buckets = { + bucket1key = "bucket1value" + } +} + +resource "aws_s3_bucket" "this" { + for_each = local.buckets + bucket = each.key +}`, + expectedBucketName: "bucket1key", + expectedNameLabel: `this["bucket1key"]`, + }, + { + name: "arg is map and ref to each.value", + src: `locals { + buckets = { + bucket1key = "bucket1value" + } +} + +resource "aws_s3_bucket" "this" { + for_each = local.buckets + bucket = each.value +}`, + expectedBucketName: "bucket1value", + expectedNameLabel: `this["bucket1key"]`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + modules := parse(t, map[string]string{ + "main.tf": tt.src, + }) + require.Len(t, modules, 1) + + buckets := modules.GetResourcesByType("aws_s3_bucket") + assert.Len(t, buckets, 1) + + bucket := buckets[0] + bucketName := bucket.GetAttribute("bucket").Value().AsString() + assert.Equal(t, tt.expectedBucketName, bucketName) + + assert.Equal(t, tt.expectedNameLabel, bucket.NameLabel()) + }) + } + +} + +func TestForEachCountExpanded(t *testing.T) { tests := []struct { name string @@ -919,6 +1005,18 @@ func TestForEach(t *testing.T) { resource "aws_s3_bucket" "this" { for_each = local.buckets bucket = each.key +}`, + expectedCount: 2, + }, + { + name: "arg is empty list", + source: `locals { + buckets = [] +} + +resource "aws_s3_bucket" "this" { + for_each = local.buckets + bucket = each.value }`, expectedCount: 0, }, @@ -929,8 +1027,34 @@ resource "aws_s3_bucket" "this" { } resource "aws_s3_bucket" "this" { - for_each = loca.buckets + for_each = local.buckets + bucket = each.key +}`, + expectedCount: 0, + }, + { + name: "argument set with the same values", + source: `locals { + buckets = ["true", "true"] +} + +resource "aws_s3_bucket" "this" { + for_each = toset(local.buckets) bucket = each.key +}`, + expectedCount: 1, + }, + { + name: "arg is non-valid set", + source: `locals { + buckets = [{ + bucket1key = "bucket1value" + }] +} + +resource "aws_s3_bucket" "this" { + for_each = toset(local.buckets) + bucket = each.value }`, expectedCount: 0, }, @@ -961,18 +1085,25 @@ resource "aws_s3_bucket" "this" { }`, expectedCount: 2, }, + { + name: "arg is empty map", + source: `locals { + buckets = {} +} +resource "aws_s3_bucket" "this" { + for_each = local.buckets + bucket = each.value +} + `, + expectedCount: 0, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - fs := testutil.CreateFS(t, map[string]string{ + modules := parse(t, map[string]string{ "main.tf": tt.source, }) - parser := New(fs, "", OptionStopOnHCLError(true)) - require.NoError(t, parser.ParseFS(context.TODO(), ".")) - - modules, _, err := parser.EvaluateAll(context.TODO()) - assert.NoError(t, err) assert.Len(t, modules, 1) bucketBlocks := modules.GetResourcesByType("aws_s3_bucket") @@ -1139,3 +1270,165 @@ func TestForEachWithObjectsOfDifferentTypes(t *testing.T) { assert.NoError(t, err) assert.Len(t, modules, 1) } + +func TestDynamicBlocks(t *testing.T) { + t.Run("arg is list of int", func(t *testing.T) { + modules := parse(t, map[string]string{ + "main.tf": ` +resource "aws_security_group" "sg-webserver" { + vpc_id = "1111" + dynamic "ingress" { + for_each = [80, 443] + content { + from_port = ingress.value + to_port = ingress.value + protocol = "tcp" + cidr_blocks = ["0.0.0.0/0"] + } + } +} +`, + }) + require.Len(t, modules, 1) + + secGroups := modules.GetResourcesByType("aws_security_group") + assert.Len(t, secGroups, 1) + ingressBlocks := secGroups[0].GetBlocks("ingress") + assert.Len(t, ingressBlocks, 2) + + var inboundPorts []int + for _, ingress := range ingressBlocks { + fromPort := ingress.GetAttribute("from_port").AsIntValueOrDefault(-1, ingress).Value() + inboundPorts = append(inboundPorts, fromPort) + } + + assert.True(t, compareSets([]int{80, 443}, inboundPorts)) + }) + + t.Run("empty for-each", func(t *testing.T) { + modules := parse(t, map[string]string{ + "main.tf": ` +resource "aws_lambda_function" "analyzer" { + dynamic "vpc_config" { + for_each = [] + content {} + } +} +`, + }) + require.Len(t, modules, 1) + + functions := modules.GetResourcesByType("aws_lambda_function") + assert.Len(t, functions, 1) + vpcConfigs := functions[0].GetBlocks("vpc_config") + assert.Empty(t, vpcConfigs) + }) + + t.Run("arg is list of bool", func(t *testing.T) { + modules := parse(t, map[string]string{ + "main.tf": ` +resource "aws_lambda_function" "analyzer" { + dynamic "vpc_config" { + for_each = [true] + content {} + } +} +`, + }) + require.Len(t, modules, 1) + + functions := modules.GetResourcesByType("aws_lambda_function") + assert.Len(t, functions, 1) + vpcConfigs := functions[0].GetBlocks("vpc_config") + assert.Len(t, vpcConfigs, 1) + }) + + t.Run("arg is list of objects", func(t *testing.T) { + modules := parse(t, map[string]string{ + "main.tf": `locals { + cluster_network_policy = [{ + enabled = true + }] +} + +resource "google_container_cluster" "primary" { + name = "test" + + dynamic "network_policy" { + for_each = local.cluster_network_policy + + content { + enabled = network_policy.value.enabled + } + } +}`, + }) + require.Len(t, modules, 1) + + clusters := modules.GetResourcesByType("google_container_cluster") + assert.Len(t, clusters, 1) + + networkPolicies := clusters[0].GetBlocks("network_policy") + assert.Len(t, networkPolicies, 1) + + enabled := networkPolicies[0].GetAttribute("enabled") + assert.True(t, enabled.Value().True()) + }) + + t.Run("nested dynamic", func(t *testing.T) { + modules := parse(t, map[string]string{ + "main.tf": ` +resource "test_block" "this" { + name = "name" + location = "loc" + dynamic "env" { + for_each = ["1", "2"] + content { + dynamic "value_source" { + for_each = [true, true] + content {} + } + } + } +}`, + }) + require.Len(t, modules, 1) + + testResources := modules.GetResourcesByType("test_block") + assert.Len(t, testResources, 1) + envs := testResources[0].GetBlocks("env") + assert.Len(t, envs, 2) + + var sources []*terraform.Block + for _, env := range envs { + sources = append(sources, env.GetBlocks("value_source")...) + } + assert.Len(t, sources, 4) + }) +} + +func parse(t *testing.T, files map[string]string) terraform.Modules { + fs := testutil.CreateFS(t, files) + parser := New(fs, "", OptionStopOnHCLError(true)) + require.NoError(t, parser.ParseFS(context.TODO(), ".")) + + modules, _, err := parser.EvaluateAll(context.TODO()) + require.NoError(t, err) + + return modules +} + +func compareSets(a []int, b []int) bool { + m := make(map[int]bool) + for _, el := range a { + m[el] = true + } + + for _, el := range b { + if !m[el] { + return false + } + } + + return true +} diff --git a/pkg/iac/terraform/block.go b/pkg/iac/terraform/block.go index bfab10bf316c..6807fddd0f7d 100644 --- a/pkg/iac/terraform/block.go +++ b/pkg/iac/terraform/block.go @@ -145,11 +145,11 @@ func (b *Block) InjectBlock(block *Block, name string) { b.childBlocks = append(b.childBlocks, block) } -func (b *Block) MarkCountExpanded() { +func (b *Block) MarkExpanded() { b.expanded = true } -func (b *Block) IsCountExpanded() bool { +func (b *Block) IsExpanded() bool { return b.expanded } @@ -187,7 +187,7 @@ func (b *Block) Clone(index cty.Value) *Block { } indexVal, _ := gocty.ToCtyValue(index, cty.Number) clone.context.SetByDot(indexVal, "count.index") - clone.MarkCountExpanded() + clone.MarkExpanded() b.cloneIndex++ return clone }