diff --git a/go/tools/sizegen/sizegen.go b/go/tools/sizegen/sizegen.go index 4de3862f96d..6281cd1485e 100644 --- a/go/tools/sizegen/sizegen.go +++ b/go/tools/sizegen/sizegen.go @@ -460,58 +460,62 @@ func main() { flag.BoolVar(&verify, "verify", false, "ensure that the generated files are correct") flag.Parse() - loaded, err := packages.Load(&packages.Config{ - Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesSizes | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedModule, - Logf: log.Printf, - }, patterns...) - + result, err := GenerateSizeHelpers(patterns, generate) if err != nil { log.Fatal(err) } - sizegen, err := generateCode(loaded, generate) - if err != nil { - log.Fatal(err) - } - - result := sizegen.finalize() - if verify { - verifyFilesOnDisk(result) + for _, err := range VerifyFilesOnDisk(result) { + log.Fatal(err) + } + log.Printf("%d files OK", len(result)) } else { - saveFilesToDisk(result) - } -} - -func saveFilesToDisk(result map[string]*jen.File) { - for fullPath, file := range result { - if err := file.Save(fullPath); err != nil { - log.Fatalf("filed to save file to '%s': %v", fullPath, err) + for fullPath, file := range result { + if err := file.Save(fullPath); err != nil { + log.Fatalf("filed to save file to '%s': %v", fullPath, err) + } + log.Printf("saved '%s'", fullPath) } - log.Printf("saved '%s'", fullPath) } } -func verifyFilesOnDisk(result map[string]*jen.File) { +// VerifyFilesOnDisk compares the generated results from the codegen against the files that +// currently exist on disk and returns any mismatches +func VerifyFilesOnDisk(result map[string]*jen.File) (errors []error) { for fullPath, file := range result { existing, err := ioutil.ReadFile(fullPath) if err != nil { - log.Fatalf("missing file on disk: %s (%v)", fullPath, err) + errors = append(errors, fmt.Errorf("missing file on disk: %s (%w)", fullPath, err)) + continue } var buf bytes.Buffer if err := file.Render(&buf); err != nil { - log.Fatalf("render error for '%s': %v", fullPath, err) + errors = append(errors, fmt.Errorf("render error for '%s': %w", fullPath, err)) + continue } if !bytes.Equal(existing, buf.Bytes()) { - log.Fatalf("'%s' has changed!", fullPath) + errors = append(errors, fmt.Errorf("'%s' has changed", fullPath)) + continue } } - log.Printf("%d files OK", len(result)) + return errors } -func generateCode(loaded []*packages.Package, generate typePaths) (*sizegen, error) { +// GenerateSizeHelpers generates the auxiliary code that implements CachedSize helper methods +// for all the types listed in typePatterns +func GenerateSizeHelpers(packagePatterns []string, typePatterns []string) (map[string]*jen.File, error) { + loaded, err := packages.Load(&packages.Config{ + Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesSizes | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedModule, + Logf: log.Printf, + }, packagePatterns...) + + if err != nil { + return nil, err + } + sizegen := newSizegen(loaded[0].Module, loaded[0].TypesSizes) scopes := make(map[string]*types.Scope) @@ -519,7 +523,7 @@ func generateCode(loaded []*packages.Package, generate typePaths) (*sizegen, err scopes[pkg.PkgPath] = pkg.Types.Scope() } - for _, gen := range generate { + for _, gen := range typePatterns { pos := strings.LastIndexByte(gen, '.') if pos < 0 { return nil, fmt.Errorf("unexpected input type: %s", gen) @@ -547,5 +551,5 @@ func generateCode(loaded []*packages.Package, generate typePaths) (*sizegen, err } } - return sizegen, nil + return sizegen.finalize(), nil } diff --git a/go/tools/sizegen/sizegen_test.go b/go/tools/sizegen/sizegen_test.go index 4eb523aa383..7b549d727c4 100644 --- a/go/tools/sizegen/sizegen_test.go +++ b/go/tools/sizegen/sizegen_test.go @@ -17,68 +17,23 @@ limitations under the License. package main import ( - "io/ioutil" - "log" - "os" - "os/exec" - "path" + "fmt" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/tools/go/packages" ) -func createFile(dir, fileName, code string) error { - s := path.Join(dir, fileName) - return ioutil.WriteFile(s, []byte(code), os.ModePerm) -} - -func TestName(t *testing.T) { - dir, err := ioutil.TempDir("", "src") - require.NoError(t, err) - command := exec.Command("go", "mod", "init", "example.com/m") - command.Dir = dir - command.Stdout = os.Stdout - command.Stderr = os.Stderr - err = command.Run() - require.NoError(t, err) - - code := ` -package code - -type A struct { - str string - field uint64 -} - -type B struct { - field1 uint64 - field2 *A -} - -` - - err = createFile(dir, "a.go", code) - require.NoError(t, err) - - config := &packages.Config{ - Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesSizes | packages.NeedTypesInfo | packages.NeedDeps | packages.NeedImports | packages.NeedModule, - Logf: log.Printf, - Dir: dir, - } - join := path.Join(dir, "...") - initial, err := packages.Load(config, join) +func TestFullGeneration(t *testing.T) { + result, err := GenerateSizeHelpers([]string{"./integration/..."}, []string{"vitess.io/vitess/go/tools/sizegen/integration.*"}) require.NoError(t, err) - pkg := initial[0] - require.Empty(t, pkg.Errors) - assert.NotNil(t, pkg.Module) - - generator, err := generateCode(initial, []string{"example.com/m.A", "example.com/m.B"}) - require.NoError(t, err) + verifyErrors := VerifyFilesOnDisk(result) + require.Empty(t, verifyErrors) - for _, file := range generator.finalize() { - t.Logf("%#v", file) + for _, file := range result { + contents := fmt.Sprintf("%#v", file) + require.Contains(t, contents, "http://www.apache.org/licenses/LICENSE-2.0") + require.Contains(t, contents, "type cachedObject interface") + require.Contains(t, contents, "//go:nocheckptr") } }