diff --git a/main.go b/main.go index b2ef240..60360e8 100644 --- a/main.go +++ b/main.go @@ -114,35 +114,75 @@ func main() { os.Exit(1) } + // Get non-flag command-line args + patterns := flag.Args() + // convert -skip flags to -ignore equivalents for _, s := range skipExtensionFlags { ignorePatterns = append(ignorePatterns, fmt.Sprintf("**/*.%s", s)) } + + // create logger to print updates to stdout + logger := log.Default() + + // real main + err := Run( + ignorePatterns, + spdx, + *holder, + *license, + *licensef, + *year, + *verbose, + *checkonly, + patterns, + logger, + ) + + if err != nil { + log.Fatal(err) + } +} + +// Run executes addLicense with supplied variables +func Run( + ignorePatterns stringSlice, + spdx spdxFlag, + holder string, + license string, + licensef string, + year string, + verbose bool, + checkonly bool, + patterns []string, + logger *log.Logger, +) error { + // verify that all ignorePatterns are valid for _, p := range ignorePatterns { if !doublestar.ValidatePattern(p) { - log.Fatalf("-ignore pattern %q is not valid", p) + return fmt.Errorf("-ignore pattern %q is not valid", p) } } // map legacy license values - if t, ok := legacyLicenseTypes[*license]; ok { - *license = t + if t, ok := legacyLicenseTypes[license]; ok { + license = t } data := licenseData{ - Year: *year, - Holder: *holder, - SPDXID: *license, + Year: year, + Holder: holder, + SPDXID: license, } - tpl, err := fetchTemplate(*license, *licensef, spdx) + tpl, err := fetchTemplate(license, licensef, spdx) if err != nil { - log.Fatal(err) + return err } t, err := template.New("").Parse(tpl) if err != nil { - log.Fatal(err) + return err } // process at most 1000 files in parallel @@ -153,11 +193,11 @@ func main() { for f := range ch { f := f // https://golang.org/doc/faq#closures_and_goroutines wg.Go(func() error { - if *checkonly { + if checkonly { // Check if file extension is known lic, err := licenseHeader(f.path, t, data) if err != nil { - log.Printf("%s: %v", f.path, err) + logger.Printf("%s: %v", f.path, err) return err } if lic == nil { // Unknown fileExtension @@ -166,21 +206,21 @@ func main() { // Check if file has a license hasLicense, err := fileHasLicense(f.path) if err != nil { - log.Printf("%s: %v", f.path, err) + logger.Printf("%s: %v", f.path, err) return err } if !hasLicense { - fmt.Printf("%s\n", f.path) + logger.Printf("%s\n", f.path) return errors.New("missing license header") } } else { modified, err := addLicense(f.path, f.mode, t, data) if err != nil { - log.Printf("%s: %v", f.path, err) + logger.Printf("%s: %v", f.path, err) return err } - if *verbose && modified { - log.Printf("%s modified", f.path) + if verbose && modified { + logger.Printf("%s modified", f.path) } } return nil @@ -193,13 +233,15 @@ func main() { } }() - for _, d := range flag.Args() { - if err := walk(ch, d); err != nil { - log.Fatal(err) + for _, d := range patterns { + if err := walk(ch, d, logger); err != nil { + return err } } close(ch) <-done + + return nil } type file struct { @@ -207,17 +249,17 @@ type file struct { mode os.FileMode } -func walk(ch chan<- *file, start string) error { +func walk(ch chan<- *file, start string, logger *log.Logger) error { return filepath.Walk(start, func(path string, fi os.FileInfo, err error) error { if err != nil { - log.Printf("%s error: %v", path, err) + logger.Printf("%s error: %v", path, err) return nil } if fi.IsDir() { return nil } if fileMatches(path, ignorePatterns) { - log.Printf("skipping: %s", path) + logger.Printf("skipping: %s", path) return nil } ch <- &file{path, fi.Mode()}