diff --git a/tfupdate/hclwrite.go b/tfupdate/hclwrite.go index 6740bb7..50ecf3d 100644 --- a/tfupdate/hclwrite.go +++ b/tfupdate/hclwrite.go @@ -5,9 +5,8 @@ import ( "reflect" "github.com/hashicorp/hcl/v2" - "github.com/hashicorp/hcl/v2/hclsyntax" + "github.com/hashicorp/hcl/v2/hclparse" "github.com/hashicorp/hcl/v2/hclwrite" - "github.com/zclconf/go-cty/cty" ) // allMatchingBlocks returns all matching blocks from the body that have the @@ -45,27 +44,40 @@ func allMatchingBlocksByType(b *hclwrite.Body, typeName string) []*hclwrite.Bloc return matched } -// getAttributeValue extracts cty.Value from hclwrite.Attribute. +// getHCLNativeAttribute gets hclwrite.Attribute as a native hcl.Attribute. // At the time of writing, there is no way to do with the hclwrite AST, -// so we build low-level byte sequences and parse an expression as a -// hclsyntax.Expression on memory. -func getAttributeValue(attr *hclwrite.Attribute) (cty.Value, error) { +// so we build low-level byte sequences and parse an attribute as a +// hcl.Attribute on memory. +// If not found, returns nil without an error. +func getHCLNativeAttribute(body *hclwrite.Body, name string) (*hcl.Attribute, error) { + attr := body.GetAttribute(name) + if attr == nil { + return nil, nil + } + // build low-level byte sequences - src := attr.Expr().BuildTokens(nil).Bytes() + attrAsBytes := attr.Expr().BuildTokens(nil).Bytes() + src := append([]byte(name+" = "), attrAsBytes...) - // parse an expression as a hclsyntax.Expression - expr, diags := hclsyntax.ParseExpression(src, "generated_by_attributeToValue", hcl.Pos{Line: 1, Column: 1}) + // parse an expression as a hcl.File. + // Note that an attribute may contains references, which are defined outside the file. + // So we cannot simply use hclsyntax.ParseExpression or hclsyntax.ParseConfig here. + // We need to use a loe-level parser not to resolve all references. + parser := hclparse.NewParser() + file, diags := parser.ParseHCL(src, "generated_by_getHCLNativeAttribute") if diags.HasErrors() { - return cty.NilVal, fmt.Errorf("failed to parse expression: %s", diags) + return nil, fmt.Errorf("failed to parse expression: %s", diags) } - // Get value from expression. - // We don't need interpolation for any variables and functions here, - // so we just pass an empty context. - v, diags := expr.Value(&hcl.EvalContext{}) + attrs, diags := file.Body.JustAttributes() if diags.HasErrors() { - return cty.NilVal, fmt.Errorf("failed to get cty.Value: %s", diags) + return nil, fmt.Errorf("failed to get attributes: %s", diags) + } + + hclAttr, ok := attrs[name] + if !ok { + return nil, fmt.Errorf("attribute not found: %s", src) } - return v, nil + return hclAttr, nil } diff --git a/tfupdate/hclwrite_test.go b/tfupdate/hclwrite_test.go index 75d5e1a..959dea5 100644 --- a/tfupdate/hclwrite_test.go +++ b/tfupdate/hclwrite_test.go @@ -2,12 +2,13 @@ package tfupdate import ( "fmt" + "reflect" "strings" "testing" "github.com/hashicorp/hcl/v2" + "github.com/hashicorp/hcl/v2/hclsyntax" "github.com/hashicorp/hcl/v2/hclwrite" - "github.com/zclconf/go-cty/cty" ) func TestAllMatchingBlocks(t *testing.T) { @@ -159,43 +160,88 @@ service "label1" { } } -func TestGetAttributeValue(t *testing.T) { - tests := []struct { - valueAsString string - want cty.Value - ok bool +func TestGetHCLNativeAttributeValue(t *testing.T) { + cases := []struct { + desc string + src string + name string + wantExprType hcl.Expression + ok bool }{ { - want: cty.StringVal("FOO"), - ok: true, + desc: "string literal", + src: ` +foo = "123" +`, + name: "foo", + wantExprType: &hclsyntax.TemplateExpr{}, + ok: true, }, { - want: cty.ObjectVal(map[string]cty.Value{ - "foo": cty.StringVal("FOO"), - "bar": cty.StringVal("BAR"), - }), - ok: true, + desc: "object literal", + src: ` +foo = { + bar = "123" + baz = "BAZ" +} +`, + name: "foo", + wantExprType: &hclsyntax.ObjectConsExpr{}, + ok: true, + }, + { + desc: "object with references", + src: ` +foo = { + bar = "123" + baz = "BAZ" + + items = [ + var.aaa, + var.bbb, + ] +} +`, + name: "foo", + wantExprType: &hclsyntax.ObjectConsExpr{}, + ok: true, + }, + { + desc: "not found", + src: ` +foo = "123" +`, + name: "bar", + wantExprType: nil, + ok: true, }, } - for _, test := range tests { - t.Run(fmt.Sprintf("%s", test.valueAsString), func(t *testing.T) { - // build hclwrite.Attribute - f := hclwrite.NewEmptyFile() - f.Body().SetAttributeValue("test", test.want) - attr := f.Body().GetAttribute("test") - - got, err := getAttributeValue(attr) - if test.ok && err != nil { - t.Errorf("getAttributeValue() with attr = %s returns unexpected err: %+v", test.want, err) + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + f, diags := hclwrite.ParseConfig([]byte(tc.src), "", hcl.Pos{Line: 1, Column: 1}) + if len(diags) != 0 { + for _, diag := range diags { + t.Logf("- %s", diag.Error()) + } + t.Fatalf("unexpected diagnostics") } - if !test.ok && err == nil { - t.Errorf("getAttributeValue() with attr = %s expects to return an error, but no error", test.want) + got, err := getHCLNativeAttribute(f.Body(), tc.name) + if tc.ok && err != nil { + t.Errorf("unexpected err: %#v", err) } - if !got.RawEquals(test.want) { - t.Errorf("getAttributeValue() with attr = %s returns %#v, but want = %#v", test.want, got, test.want) + if !tc.ok && err == nil { + t.Errorf("expects to return an error, but no error. got = %#v", got) + } + + if tc.ok && got != nil { + // An expression is a complicated object and hard to build from literal. + // So we simply compare it by type. + if reflect.TypeOf(got.Expr) != reflect.TypeOf(tc.wantExprType) { + t.Errorf("got = %#v, but want = %#v", got.Expr, tc.wantExprType) + } } }) } diff --git a/tfupdate/provider.go b/tfupdate/provider.go index 1a05ff7..601c73d 100644 --- a/tfupdate/provider.go +++ b/tfupdate/provider.go @@ -1,6 +1,9 @@ package tfupdate import ( + "fmt" + + "github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2/hclsyntax" "github.com/hashicorp/hcl/v2/hclwrite" "github.com/pkg/errors" @@ -50,24 +53,25 @@ func (u *ProviderUpdater) updateTerraformBlock(f *hclwrite.File) error { continue } - attr := p.Body().GetAttribute(u.name) - if attr != nil { - value, err := getAttributeValue(attr) - if err != nil { - return err - } + // The hclwrite.Attribute doesn't have enough AST for object type to check. + // Get the attribute as a native hcl.Attribute as a compromise. + hclAttr, err := getHCLNativeAttribute(p.Body(), u.name) + if err != nil { + return err + } + if hclAttr != nil { // There are some variations on the syntax of required_providers. - // So we check a type of value and switch implementations. - switch { - case value.Type().IsObjectType(): - u.updateTerraformRequiredProvidersBlockAsObject(p, value) - - case value.Type() == cty.String: + // So we check a type of the value and switch implementations. + // If the expression can be parsed as a static expression and it's type is a primitive, + // then it's a legacy string syntax. + if expr, err := hclAttr.Expr.Value(nil); err == nil && expr.Type().IsPrimitiveType() { u.updateTerraformRequiredProvidersBlockAsString(p) - - default: - return errors.Errorf("failed to update required_providers. unknown type: %#v", value) + } else { + // Otherwise, it's an object syntax. + if err := u.updateTerraformRequiredProvidersBlockAsObject(p, hclAttr); err != nil { + return err + } } } } @@ -75,28 +79,35 @@ func (u *ProviderUpdater) updateTerraformBlock(f *hclwrite.File) error { return nil } -func (u *ProviderUpdater) updateTerraformRequiredProvidersBlockAsObject(p *hclwrite.Block, value cty.Value) { +func (u *ProviderUpdater) updateTerraformRequiredProvidersBlockAsObject(p *hclwrite.Block, hclAttr *hcl.Attribute) error { // terraform { // required_providers { // aws = { // source = "hashicorp/aws" // version = "2.65.0" + // + // configuration_aliases = [ + // aws.primary, + // aws.secondary, + // ] // } // } // } - m := value.AsValueMap() - if _, ok := m["version"]; !ok { + + oldVersion, err := detectVersionInObject(hclAttr) + if err != nil { + return err + } + + if len(oldVersion) == 0 { // If the version key is missing, just ignore it. - return + return nil } // Updating the whole object loses original sort order and comments. // At the time of writing, there is no way to update a value inside an // object directly while preserving original tokens. // - // m["version"] = cty.StringVal(u.version) - // p.Body().SetAttributeValue(u.name, cty.ObjectVal(m)) - // // Since we fully understand the valid syntax, we compromise and read the // tokens in order, updating the bytes directly. // It's apparently a fragile dirty hack, but I didn't come up with the better @@ -116,7 +127,6 @@ func (u *ProviderUpdater) updateTerraformRequiredProvidersBlockAsObject(p *hclwr } // find value of old version - oldVersion := m["version"].AsString() for !(tokens[i].Type == hclsyntax.TokenQuotedLit && string(tokens[i].Bytes) == oldVersion) { i++ } @@ -126,7 +136,37 @@ func (u *ProviderUpdater) updateTerraformRequiredProvidersBlockAsObject(p *hclwr // So we now update bytes of the token in place. tokens[i].Bytes = []byte(u.version) - return + return nil +} + +// detectVersionInObject parses an object expression and detects a value for +// the "version" key. +// If the version key is missing, just returns an empty string without an error. +func detectVersionInObject(hclAttr *hcl.Attribute) (string, error) { + // The configuration_aliases syntax isn't directly related version updateing, + // but it contains provider references and causes an parse error without an EvalContext. + // So we treat the expression as a hcl.ExprMap to avoid fully decoding the object. + kvs, diags := hcl.ExprMap(hclAttr.Expr) + if diags.HasErrors() { + return "", fmt.Errorf("failed to parse expr as hcl.ExprMap: %s", diags) + } + + oldVersion := "" + for _, kv := range kvs { + key, diags := kv.Key.Value(nil) + if diags.HasErrors() { + return "", fmt.Errorf("failed to get key: %s", diags) + } + if key.AsString() == "version" { + value, diags := kv.Value.Value(nil) + if diags.HasErrors() { + return "", fmt.Errorf("failed to get value: %s", diags) + } + oldVersion = value.AsString() + } + } + + return oldVersion, nil } func (u *ProviderUpdater) updateTerraformRequiredProvidersBlockAsString(p *hclwrite.Block) { diff --git a/tfupdate/provider_test.go b/tfupdate/provider_test.go index a94fd53..a52dbf0 100644 --- a/tfupdate/provider_test.go +++ b/tfupdate/provider_test.go @@ -315,6 +315,41 @@ terraform { } } } +`, + ok: true, + }, + { + src: ` +terraform { + required_providers { + aws = { + version = "2.65.0" + source = "hashicorp/aws" + + configuration_aliases = [ + aws.primary, + aws.secondary, + ] + } + } +} +`, + name: "aws", + version: "2.66.0", + want: ` +terraform { + required_providers { + aws = { + version = "2.66.0" + source = "hashicorp/aws" + + configuration_aliases = [ + aws.primary, + aws.secondary, + ] + } + } +} `, ok: true, },