From 39e30a8ecb32aaac15928c252d0a2968b6c467db Mon Sep 17 00:00:00 2001 From: grindlemire Date: Sun, 28 Jan 2024 17:00:05 -0700 Subject: [PATCH] fix concurrent write problem --- cmd/templ/generatecmd/main.go | 86 +++++++++++++++---- cmd/templ/generatecmd/testwatch/watch_test.go | 5 +- 2 files changed, 71 insertions(+), 20 deletions(-) diff --git a/cmd/templ/generatecmd/main.go b/cmd/templ/generatecmd/main.go index fc2652302..39301f2f5 100644 --- a/cmd/templ/generatecmd/main.go +++ b/cmd/templ/generatecmd/main.go @@ -125,7 +125,13 @@ func runCmd(ctx context.Context, w io.Writer, args Arguments) error { return generateProduction(context.Background(), w, args, opts, p) } -func generateWatched(ctx context.Context, w io.Writer, args Arguments, opts []generator.GenerateOpt, p *proxy.Handler) error { +func generateWatched( + ctx context.Context, + w io.Writer, + args Arguments, + opts []generator.GenerateOpt, + p *proxy.Handler, +) error { fmt.Fprintln(w, "Generating dev code:", args.Path) start := time.Now() @@ -136,7 +142,7 @@ func generateWatched(ctx context.Context, w io.Writer, args Arguments, opts []ge var firstRunComplete bool fileNameToLastModTime := make(map[string]time.Time) - fileNameToHash := make(map[string][sha256.Size]byte) + fileNameToHash := &sync.Map{} for !firstRunComplete || args.Watch { changesFound, errs := processChanges( @@ -155,7 +161,13 @@ func generateWatched(ctx context.Context, w io.Writer, args Arguments, opts []ge } if changesFound > 0 { if len(errs) > 0 { - logError(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start)) + logError( + w, + "Generated code for %d templates with %d errors in %s\n", + changesFound, + len(errs), + time.Since(start), + ) } else { logSuccess(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start)) } @@ -202,12 +214,18 @@ func generateWatched(ctx context.Context, w io.Writer, args Arguments, opts []ge return nil } -func generateProduction(ctx context.Context, w io.Writer, args Arguments, opts []generator.GenerateOpt, p *proxy.Handler) error { +func generateProduction( + ctx context.Context, + w io.Writer, + args Arguments, + opts []generator.GenerateOpt, + p *proxy.Handler, +) error { fmt.Fprintln(w, "Generating production code:", args.Path) start := time.Now() changesFound, errs := processChanges( - ctx, w, nil, nil, + ctx, w, nil, &sync.Map{}, args.Path, args.GenerateSourceMapVisualisations, opts, args.WorkerCount, false, args.KeepOrphanedFiles) if len(errs) > 0 { @@ -219,7 +237,13 @@ func generateProduction(ctx context.Context, w io.Writer, args Arguments, opts [ if changesFound > 0 { if len(errs) > 0 { - logError(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start)) + logError( + w, + "Generated code for %d templates with %d errors in %s\n", + changesFound, + len(errs), + time.Since(start), + ) } else { logSuccess(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start)) } @@ -249,7 +273,17 @@ func shouldSkipDir(dir string) bool { return false } -func processChanges(ctx context.Context, stdout io.Writer, fileNameToLastModTime map[string]time.Time, hashes map[string][sha256.Size]byte, path string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt, maxWorkerCount int, watching, keepOrphanedFiles bool) (changesFound int, errs []error) { +func processChanges( + ctx context.Context, + stdout io.Writer, + fileNameToLastModTime map[string]time.Time, + hashes *sync.Map, + path string, + generateSourceMapVisualisations bool, + opts []generator.GenerateOpt, + maxWorkerCount int, + watching, keepOrphanedFiles bool, +) (changesFound int, errs []error) { sem := make(chan struct{}, maxWorkerCount) var wg sync.WaitGroup @@ -344,7 +378,14 @@ func openURL(w io.Writer, url string) error { // processSingleFile generates Go code for a single template. // If a basePath is provided, the filename included in error messages is relative to it. -func processSingleFile(ctx context.Context, stdout io.Writer, basePath, fileName string, hashes map[string][sha256.Size]byte, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) (err error) { +func processSingleFile( + ctx context.Context, + stdout io.Writer, + basePath, fileName string, + hashes *sync.Map, + generateSourceMapVisualisations bool, + opts []generator.GenerateOpt, +) (err error) { start := time.Now() diag, err := generate(ctx, basePath, fileName, hashes, generateSourceMapVisualisations, opts) if err != nil { @@ -373,15 +414,17 @@ func printDiagnostics(w io.Writer, fileName string, diags []parser.Diagnostic) { // generate Go code for a single template. // If a basePath is provided, the filename included in error messages is relative to it. -func generate(ctx context.Context, basePath, fileName string, hashes map[string][sha256.Size]byte, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) (diagnostics []parser.Diagnostic, err error) { +func generate( + ctx context.Context, + basePath, fileName string, + hashes *sync.Map, + generateSourceMapVisualisations bool, + opts []generator.GenerateOpt, +) (diagnostics []parser.Diagnostic, err error) { if err = ctx.Err(); err != nil { return } - if hashes == nil { - hashes = make(map[string][sha256.Size]byte) - } - t, err := parser.Parse(fileName) if err != nil { return nil, fmt.Errorf("%s parsing error: %w", fileName, err) @@ -407,22 +450,25 @@ func generate(ctx context.Context, basePath, fileName string, hashes map[string] // Hash output, and write out the file if the goCodeHash has changed. goCodeHash := sha256.Sum256(formattedGoCode) - if hashes[targetFileName] != goCodeHash { + + targetHash, _ := hashes.Load(targetFileName) + if targetHash != nil && targetHash.([sha256.Size]byte) != goCodeHash { if err = os.WriteFile(targetFileName, formattedGoCode, 0o644); err != nil { return nil, fmt.Errorf("failed to write target file %q: %w", targetFileName, err) } - hashes[targetFileName] = goCodeHash + hashes.Store(targetFileName, goCodeHash) } // Add the txt file if it has changed. if len(literals) > 0 { txtFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.txt" txtHash := sha256.Sum256([]byte(literals)) - if hashes[txtFileName] != txtHash { + targetHash, _ := hashes.Load(txtFileName) + if targetHash != nil && targetHash.([sha256.Size]byte) != txtHash { if err = os.WriteFile(txtFileName, []byte(literals), 0o644); err != nil { return nil, fmt.Errorf("failed to write string literal file %q: %w", txtFileName, err) } - hashes[txtFileName] = txtHash + hashes.Store(txtFileName, txtHash) } } @@ -432,7 +478,11 @@ func generate(ctx context.Context, basePath, fileName string, hashes map[string] return t.Diagnostics, err } -func generateSourceMapVisualisation(ctx context.Context, templFileName, goFileName string, sourceMap *parser.SourceMap) error { +func generateSourceMapVisualisation( + ctx context.Context, + templFileName, goFileName string, + sourceMap *parser.SourceMap, +) error { if err := ctx.Err(); err != nil { return err } diff --git a/cmd/templ/generatecmd/testwatch/watch_test.go b/cmd/templ/generatecmd/testwatch/watch_test.go index 65cd3c210..de6dea287 100644 --- a/cmd/templ/generatecmd/testwatch/watch_test.go +++ b/cmd/templ/generatecmd/testwatch/watch_test.go @@ -25,6 +25,7 @@ import ( var testdata embed.FS func createTestProject(moduleRoot string) (dir string, err error) { + fmt.Printf("creating test project\n") dir, err = os.MkdirTemp("", "templ_watch_test_*") if err != nil { return dir, fmt.Errorf("failed to make test dir: %w", err) @@ -45,7 +46,7 @@ func createTestProject(moduleRoot string) (dir string, err error) { data = bytes.ReplaceAll(data, []byte("{moduleRoot}"), []byte(moduleRoot)) target = filepath.Join(dir, "go.mod") } - err = os.WriteFile(target, data, 0660) + err = os.WriteFile(target, data, 0o660) if err != nil { return dir, fmt.Errorf("failed to copy file: %w", err) } @@ -59,7 +60,7 @@ func replaceInFile(name, src, tgt string) error { return err } updated := strings.Replace(string(data), src, tgt, -1) - return os.WriteFile(name, []byte(updated), 0660) + return os.WriteFile(name, []byte(updated), 0o660) } func getPort() (port int, err error) {