From ddad04cfb8340d0da4b5e9952d6bc1a62b5dd1d8 Mon Sep 17 00:00:00 2001 From: Hironao OTSUBO Date: Tue, 20 Sep 2022 23:50:14 +0900 Subject: [PATCH] allow updating expected vars/consts inside functions --- assert/assert_ext_test.go | 71 ++++++++++++++++++++++++++++++++++++++- internal/source/update.go | 51 +++++++++++++++++----------- 2 files changed, 102 insertions(+), 20 deletions(-) diff --git a/assert/assert_ext_test.go b/assert/assert_ext_test.go index 5903f70..d450ce2 100644 --- a/assert/assert_ext_test.go +++ b/assert/assert_ext_test.go @@ -1,6 +1,7 @@ package assert_test import ( + "go/ast" "go/parser" "go/token" "io/ioutil" @@ -56,6 +57,48 @@ expected value expected := "const expectedTwo = `this is the new\nexpected value\n`" assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw)) }) + + t.Run("var inside function is updated when -update=true", func(t *testing.T) { + patchUpdate(t) + t.Cleanup(func() { + resetVariable(t, "expectedInsideFunc", "") + }) + + actual := `this is the new +expected value +for var inside function +` + expectedInsideFunc := `` + + assert.Equal(t, actual, expectedInsideFunc) + + raw, err := ioutil.ReadFile(fileName(t)) + assert.NilError(t, err) + + expected := "expectedInsideFunc := `this is the new\nexpected value\nfor var inside function\n`" + assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw)) + }) + + t.Run("const inside function is updated when -update=true", func(t *testing.T) { + patchUpdate(t) + t.Cleanup(func() { + resetVariable(t, "expectedConstInsideFunc", "") + }) + + actual := `this is the new +expected value +for const inside function +` + const expectedConstInsideFunc = `` + + assert.Equal(t, actual, expectedConstInsideFunc) + + raw, err := ioutil.ReadFile(fileName(t)) + assert.NilError(t, err) + + expected := "const expectedConstInsideFunc = `this is the new\nexpected value\nfor const inside function\n`" + assert.Assert(t, strings.Contains(string(raw), expected), "actual=%v", string(raw)) + }) } // expectedOne is updated by running the tests with -update @@ -87,7 +130,33 @@ func resetVariable(t *testing.T, varName string, value string) { astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments) assert.NilError(t, err) - err = source.UpdateVariable(filename, fileset, astFile, varName, value) + var ident *ast.Ident + ast.Inspect(astFile, func(n ast.Node) bool { + switch v := n.(type) { + case *ast.AssignStmt: + if len(v.Lhs) == 1 { + if id, ok := v.Lhs[0].(*ast.Ident); ok { + if id.Name == varName { + ident = id + return false + } + } + } + + case *ast.ValueSpec: + for _, id := range v.Names { + if id.Name == varName { + ident = id + return false + } + } + } + + return true + }) + assert.Assert(t, ident != nil, "failed to get ident for %s", varName) + + err = source.UpdateVariable(filename, fileset, astFile, ident, value) assert.NilError(t, err, "failed to reset file") } diff --git a/internal/source/update.go b/internal/source/update.go index bd9678b..67ed7c8 100644 --- a/internal/source/update.go +++ b/internal/source/update.go @@ -54,8 +54,8 @@ func UpdateExpectedValue(stackIndex int, x, y interface{}) error { return ErrNotFound } - argIndex, varName := getVarNameForExpectedValueArg(expr) - if argIndex < 0 || varName == "" { + argIndex, ident := getVarNameForExpectedValueArg(expr) + if argIndex < 0 || ident == nil { debug("no arguments started with the word 'expected': %v", debugFormatNode{Node: &ast.CallExpr{Args: expr}}) return ErrNotFound @@ -71,7 +71,7 @@ func UpdateExpectedValue(stackIndex int, x, y interface{}) error { debug("value must be type string, got %T", value) return ErrNotFound } - return UpdateVariable(filename, fileset, astFile, varName, strValue) + return UpdateVariable(filename, fileset, astFile, ident, strValue) } // UpdateVariable writes to filename the contents of astFile with the value of @@ -80,10 +80,10 @@ func UpdateVariable( filename string, fileset *token.FileSet, astFile *ast.File, - varName string, + ident *ast.Ident, value string, ) error { - obj := astFile.Scope.Objects[varName] + obj := ident.Obj if obj == nil { return ErrNotFound } @@ -92,20 +92,33 @@ func UpdateVariable( return ErrNotFound } - spec, ok := obj.Decl.(*ast.ValueSpec) - if !ok { + switch decl := obj.Decl.(type) { + case *ast.ValueSpec: + if len(decl.Names) != 1 { + debug("more than one name in ast.ValueSpec") + return ErrNotFound + } + + decl.Values[0] = &ast.BasicLit{ + Kind: token.STRING, + Value: "`" + value + "`", + } + + case *ast.AssignStmt: + if len(decl.Lhs) != 1 { + debug("more than one name in ast.AssignStmt") + return ErrNotFound + } + + decl.Rhs[0] = &ast.BasicLit{ + Kind: token.STRING, + Value: "`" + value + "`", + } + + default: debug("can only update *ast.ValueSpec, found %T", obj.Decl) return ErrNotFound } - if len(spec.Names) != 1 { - debug("more than one name in ast.ValueSpec") - return ErrNotFound - } - - spec.Values[0] = &ast.BasicLit{ - Kind: token.STRING, - Value: "`" + value + "`", - } var buf bytes.Buffer if err := format.Node(&buf, fileset, astFile); err != nil { @@ -125,14 +138,14 @@ func UpdateVariable( return nil } -func getVarNameForExpectedValueArg(expr []ast.Expr) (int, string) { +func getVarNameForExpectedValueArg(expr []ast.Expr) (int, *ast.Ident) { for i := 1; i < 3; i++ { switch e := expr[i].(type) { case *ast.Ident: if strings.HasPrefix(strings.ToLower(e.Name), "expected") { - return i, e.Name + return i, e } } } - return -1, "" + return -1, nil }