From 1bf6faa9dcaadd893005f7deda1d734925cb16e6 Mon Sep 17 00:00:00 2001 From: Adrian Hesketh Date: Tue, 30 Jan 2024 08:45:08 +0000 Subject: [PATCH] feat: add fsnotify-based hot reload (#470) --- .version | 2 +- cmd/templ/generatecmd/cmd.go | 323 +++++++++++++++++ cmd/templ/generatecmd/eventhandler.go | 233 ++++++++++++ cmd/templ/generatecmd/fatalerror.go | 23 ++ cmd/templ/generatecmd/main.go | 478 ++----------------------- cmd/templ/generatecmd/sse/server.go | 3 +- cmd/templ/generatecmd/watcher/watch.go | 129 +++++++ cmd/templ/main.go | 4 + cmd/templ/sloghandler/handler.go | 101 ++++++ examples/counter-basic/main.go | 2 +- flake.nix | 2 +- go.mod | 1 + go.sum | 2 + 13 files changed, 842 insertions(+), 461 deletions(-) create mode 100644 cmd/templ/generatecmd/cmd.go create mode 100644 cmd/templ/generatecmd/eventhandler.go create mode 100644 cmd/templ/generatecmd/fatalerror.go create mode 100644 cmd/templ/generatecmd/watcher/watch.go create mode 100644 cmd/templ/sloghandler/handler.go diff --git a/.version b/.version index 136c78f7c..a212db8ff 100644 --- a/.version +++ b/.version @@ -1 +1 @@ -0.2.546 \ No newline at end of file +0.2.549 \ No newline at end of file diff --git a/cmd/templ/generatecmd/cmd.go b/cmd/templ/generatecmd/cmd.go new file mode 100644 index 000000000..aa89088c7 --- /dev/null +++ b/cmd/templ/generatecmd/cmd.go @@ -0,0 +1,323 @@ +package generatecmd + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/http" + "net/url" + "path" + "path/filepath" + "runtime" + "sync" + "time" + + "github.com/a-h/templ" + "github.com/a-h/templ/cmd/templ/generatecmd/modcheck" + "github.com/a-h/templ/cmd/templ/generatecmd/proxy" + "github.com/a-h/templ/cmd/templ/generatecmd/run" + "github.com/a-h/templ/cmd/templ/generatecmd/watcher" + "github.com/a-h/templ/generator" + "github.com/cenkalti/backoff/v4" + "github.com/cli/browser" + "github.com/fsnotify/fsnotify" +) + +func NewGenerate(log *slog.Logger, args Arguments) (g *Generate) { + g = &Generate{ + Log: log, + Args: &args, + } + if g.Args.WorkerCount == 0 { + g.Args.WorkerCount = runtime.NumCPU() + } + return g +} + +type Generate struct { + Log *slog.Logger + Args *Arguments +} + +type GenerationEvent struct { + Event fsnotify.Event + GoUpdated bool + TextUpdated bool +} + +func (cmd Generate) Run(ctx context.Context) (err error) { + if cmd.Args.Watch && cmd.Args.FileName != "" { + return fmt.Errorf("cannot watch a single file, remove the -f or -watch flag") + } + + if cmd.Args.PPROFPort > 0 { + go func() { + _ = http.ListenAndServe(fmt.Sprintf("localhost:%d", cmd.Args.PPROFPort), nil) + }() + } + + // Use absolute path. + if !path.IsAbs(cmd.Args.Path) { + cmd.Args.Path, err = filepath.Abs(cmd.Args.Path) + if err != nil { + return fmt.Errorf("failed to get absolute path: %w", err) + } + } + + // Configure generator. + var opts []generator.GenerateOpt + if cmd.Args.IncludeVersion { + opts = append(opts, generator.WithVersion(templ.Version())) + } + if cmd.Args.IncludeTimestamp { + opts = append(opts, generator.WithTimestamp(time.Now())) + } + + // Check the version of the templ module. + if err := modcheck.Check(cmd.Args.Path); err != nil { + cmd.Log.Warn("templ version check: " + err.Error()) + } + + fseh := NewFSEventHandler(cmd.Log, cmd.Args.Path, cmd.Args.Watch, opts, cmd.Args.GenerateSourceMapVisualisations, cmd.Args.KeepOrphanedFiles) + + // If we're processing a single file, don't bother setting up the channels/multithreaing. + if cmd.Args.FileName != "" { + _, _, err = fseh.HandleEvent(ctx, fsnotify.Event{ + Name: cmd.Args.FileName, + Op: fsnotify.Create, + }) + return err + } + + // Start timer. + start := time.Now() + + // Create channels: + // For the initial filesystem walk and subsequent (optional) fsnotify events. + events := make(chan fsnotify.Event) + // count of events currently being processed by the event handler. + var eventsWG sync.WaitGroup + // Used to check that the event handler has completed. + var eventHandlerWG sync.WaitGroup + // For errs from the watcher. + errs := make(chan error) + // For triggering actions after generation has completed. + postGeneration := make(chan *GenerationEvent, 256) + // Used to check that the post-generation handler has completed. + var postGenerationWG sync.WaitGroup + var postGenerationEventsWG sync.WaitGroup + + // Waitgroup for the push process. + var pushHandlerWG sync.WaitGroup + + // Start process to push events into the channel. + pushHandlerWG.Add(1) + go func() { + defer pushHandlerWG.Done() + defer close(events) + defer close(errs) + cmd.Log.Debug("Walking directory", slog.String("path", cmd.Args.Path), slog.Bool("devMode", cmd.Args.Watch)) + if err := watcher.WalkFiles(ctx, cmd.Args.Path, events); err != nil { + cmd.Log.Error("WalkFiles failed, exiting", slog.Any("error", err)) + errs <- FatalError{Err: fmt.Errorf("failed to walk files: %w", err)} + return + } + if !cmd.Args.Watch { + cmd.Log.Debug("Dev mode not enabled, process can finish early") + return + } + cmd.Log.Info("Watching files") + rw, err := watcher.Recursive(ctx, cmd.Args.Path, events, errs) + if err != nil { + cmd.Log.Error("Recursive watcher setup failed, exiting", slog.Any("error", err)) + errs <- FatalError{Err: fmt.Errorf("failed to setup recursive watcher: %w", err)} + return + } + cmd.Log.Debug("Waiting for context to be cancelled to stop watching files") + <-ctx.Done() + cmd.Log.Debug("Context cancelled, closing watcher") + if err := rw.Close(); err != nil { + cmd.Log.Error("Failed to close watcher", slog.Any("error", err)) + } + cmd.Log.Debug("Waiting for events to be processed") + eventsWG.Wait() + cmd.Log.Debug("All pending events processed, waiting for pending post-generation events to complete") + postGenerationEventsWG.Wait() + cmd.Log.Debug("All post-generation events processed, running walk again, but in production mode") + fseh.DevMode = false + if err := watcher.WalkFiles(ctx, cmd.Args.Path, events); err != nil { + cmd.Log.Error("Post dev mode WalkFiles failed", slog.Any("error", err)) + errs <- FatalError{Err: fmt.Errorf("failed to walk files: %w", err)} + return + } + }() + + // Start process to handle events. + eventHandlerWG.Add(1) + sem := make(chan struct{}, cmd.Args.WorkerCount) + go func() { + defer eventHandlerWG.Done() + defer close(postGeneration) + cmd.Log.Debug("Starting event handler") + for event := range events { + eventsWG.Add(1) + sem <- struct{}{} + go func(event fsnotify.Event) { + cmd.Log.Debug("Processing file", slog.String("file", event.Name)) + defer eventsWG.Done() + defer func() { <-sem }() + goUpdated, textUpdated, err := fseh.HandleEvent(ctx, event) + if err != nil { + cmd.Log.Error("Event handler failed", slog.Any("error", err)) + errs <- err + } + if goUpdated || textUpdated { + postGeneration <- &GenerationEvent{ + Event: event, + GoUpdated: goUpdated, + TextUpdated: textUpdated, + } + } + }(event) + } + // Wait for all events to be processed before closing. + eventsWG.Wait() + }() + + // Start process to handle post-generation events. + var updates int + postGenerationWG.Add(1) + var firstPostGenerationExecuted bool + go func() { + defer postGenerationWG.Done() + cmd.Log.Debug("Starting post-generation handler") + timeout := time.NewTimer(time.Hour * 24 * 365) + var goUpdated, textUpdated bool + var p *proxy.Handler + for { + select { + case ge := <-postGeneration: + if ge == nil { + cmd.Log.Debug("Post-generation event channel closed, exiting") + return + } + goUpdated = goUpdated || ge.GoUpdated + textUpdated = textUpdated || ge.TextUpdated + if goUpdated || textUpdated { + updates++ + } + // Reset timer. + if !timeout.Stop() { + <-timeout.C + } + timeout.Reset(time.Millisecond * 100) + case <-timeout.C: + if !goUpdated && !textUpdated { + // Nothing to process, reset timer and wait again. + timeout.Reset(time.Hour * 24 * 365) + break + } + postGenerationEventsWG.Add(1) + if cmd.Args.Command != "" && goUpdated { + cmd.Log.Debug("Executing command", slog.String("command", cmd.Args.Command)) + if _, err := run.Run(ctx, cmd.Args.Path, cmd.Args.Command); err != nil { + cmd.Log.Error("Error executing command", slog.Any("error", err)) + } + } + if !firstPostGenerationExecuted { + cmd.Log.Debug("First post-generation event received, starting proxy") + firstPostGenerationExecuted = true + p, err = cmd.StartProxy(ctx) + if err != nil { + cmd.Log.Error("Failed to start proxy", slog.Any("error", err)) + } + } + // Send server-sent event. + if p != nil && (textUpdated || goUpdated) { + cmd.Log.Debug("Sending reload event") + p.SendSSE("message", "reload") + } + postGenerationEventsWG.Done() + // Reset timer. + timeout.Reset(time.Millisecond * 100) + textUpdated = false + goUpdated = false + } + } + }() + + // Read errors. + for err := range errs { + if err == nil { + continue + } + if errors.Is(err, FatalError{}) { + cmd.Log.Debug("Fatal error, exiting") + return err + } + cmd.Log.Error("Error received", slog.Any("error", err)) + } + + // Wait for everything to complete. + cmd.Log.Debug("Waiting for push handler to complete") + pushHandlerWG.Wait() + cmd.Log.Debug("Waiting for event handler to complete") + eventHandlerWG.Wait() + cmd.Log.Debug("Waiting for post-generation handler to complete") + postGenerationWG.Wait() + if cmd.Args.Command != "" { + cmd.Log.Debug("Killing command", slog.String("command", cmd.Args.Command)) + if err := run.KillAll(); err != nil { + cmd.Log.Error("Error killing command", slog.Any("error", err)) + } + } + cmd.Log.Info("Complete", slog.Int("updates", updates), slog.Duration("duration", time.Since(start))) + + return nil +} + +func (cmd *Generate) StartProxy(ctx context.Context) (p *proxy.Handler, err error) { + if cmd.Args.Proxy == "" { + cmd.Log.Debug("No proxy URL specified, not starting proxy") + return nil, nil + } + var target *url.URL + target, err = url.Parse(cmd.Args.Proxy) + if err != nil { + return nil, FatalError{Err: fmt.Errorf("failed to parse proxy URL: %w", err)} + } + if cmd.Args.ProxyPort == 0 { + cmd.Args.ProxyPort = 7331 + } + p = proxy.New(cmd.Args.ProxyPort, target) + go func() { + cmd.Log.Info("Proxying", slog.String("from", p.URL), slog.String("to", p.Target.String())) + if err := http.ListenAndServe(fmt.Sprintf("127.0.0.1:%d", cmd.Args.ProxyPort), p); err != nil { + cmd.Log.Error("Proxy failed", slog.Any("error", err)) + } + }() + if !cmd.Args.OpenBrowser { + cmd.Log.Debug("Not opening browser") + return p, nil + } + go func() { + cmd.Log.Debug("Waiting for proxy to be ready", slog.String("url", p.URL)) + backoff := backoff.NewExponentialBackOff() + backoff.InitialInterval = time.Second + var client http.Client + client.Timeout = 1 * time.Second + for { + if _, err := client.Get(p.URL); err == nil { + break + } + d := backoff.NextBackOff() + cmd.Log.Debug("Proxy not ready, retrying", slog.String("url", p.URL), slog.Any("backoff", d)) + time.Sleep(d) + } + if err := browser.OpenURL(p.URL); err != nil { + cmd.Log.Error("Failed to open browser", slog.Any("error", err)) + } + }() + return p, nil +} diff --git a/cmd/templ/generatecmd/eventhandler.go b/cmd/templ/generatecmd/eventhandler.go new file mode 100644 index 000000000..eb8ef7e13 --- /dev/null +++ b/cmd/templ/generatecmd/eventhandler.go @@ -0,0 +1,233 @@ +package generatecmd + +import ( + "bufio" + "bytes" + "context" + "crypto/sha256" + "fmt" + "go/format" + "log/slog" + "os" + "path" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/a-h/templ/cmd/templ/visualize" + "github.com/a-h/templ/generator" + "github.com/a-h/templ/parser/v2" + "github.com/fsnotify/fsnotify" +) + +func NewFSEventHandler(log *slog.Logger, dir string, devMode bool, genOpts []generator.GenerateOpt, genSourceMapVis bool, keepOrphanedFiles bool) *FSEventHandler { + if !path.IsAbs(dir) { + dir, _ = filepath.Abs(dir) + } + fseh := &FSEventHandler{ + Log: log, + dir: dir, + fileNameToLastModTime: make(map[string]time.Time), + fileNameToLastModTimeMutex: &sync.Mutex{}, + hashes: make(map[string][sha256.Size]byte), + hashesMutex: &sync.Mutex{}, + genOpts: genOpts, + genSourceMapVis: genSourceMapVis, + DevMode: devMode, + keepOrphanedFiles: keepOrphanedFiles, + } + if devMode { + fseh.genOpts = append(fseh.genOpts, generator.WithExtractStrings()) + } + return fseh +} + +type FSEventHandler struct { + Log *slog.Logger + // dir is the root directory being processed. + dir string + fileNameToLastModTime map[string]time.Time + fileNameToLastModTimeMutex *sync.Mutex + hashes map[string][sha256.Size]byte + hashesMutex *sync.Mutex + genOpts []generator.GenerateOpt + genSourceMapVis bool + DevMode bool + keepOrphanedFiles bool +} + +func (h *FSEventHandler) HandleEvent(ctx context.Context, event fsnotify.Event) (goUpdated, textUpdated bool, err error) { + // Handle _templ.go files. + if !event.Has(fsnotify.Remove) && strings.HasSuffix(event.Name, "_templ.go") { + _, err = os.Stat(strings.TrimSuffix(event.Name, "_templ.go") + ".templ") + if !os.IsNotExist(err) { + return false, false, err + } + // File is orphaned. + if h.keepOrphanedFiles { + return false, false, nil + } + h.Log.Debug("Deleting orphaned Go file", slog.String("file", event.Name)) + if err = os.Remove(event.Name); err != nil { + h.Log.Warn("Failed to remove orphaned file", slog.Any("error", err)) + } + return true, false, nil + } + // Handle _templ.txt files. + if !event.Has(fsnotify.Remove) && strings.HasSuffix(event.Name, "_templ.txt") { + if h.DevMode { + // Don't delete the file if we're in dev mode, but mark that text was updated. + return false, true, nil + } + h.Log.Debug("Deleting watch mode file", slog.String("file", event.Name)) + if err = os.Remove(event.Name); err != nil { + h.Log.Warn("Failed to remove watch mode text file", slog.Any("error", err)) + return false, false, nil + } + return false, false, nil + } + + // Handle .templ files. + if !strings.HasSuffix(event.Name, ".templ") { + return false, false, nil + } + + // If the file hasn't been updated since the last time we processed it, ignore it. + if !h.UpsertLastModTime(event.Name) { + return false, false, nil + } + + // Start a processor. + start := time.Now() + goUpdated, textUpdated, diag, err := h.generate(ctx, event.Name) + if err != nil { + h.Log.Error("Error generating code", slog.String("file", event.Name), slog.Any("error", err)) + return goUpdated, textUpdated, fmt.Errorf("failed to generate code for %q: %w", event.Name, err) + } + if len(diag) > 0 { + for _, d := range diag { + h.Log.Warn(d.Message, slog.String("from", fmt.Sprintf("%d:%d", d.Range.From.Line, d.Range.From.Col)), slog.String("to", fmt.Sprintf("%d:%d", d.Range.To.Line, d.Range.To.Col))) + } + return + } + h.Log.Debug("Generated code", slog.String("file", event.Name), slog.Duration("in", time.Since(start))) + + return goUpdated, textUpdated, nil +} + +func (h *FSEventHandler) UpsertLastModTime(fileName string) (updated bool) { + fileInfo, err := os.Stat(fileName) + if err != nil { + return false + } + h.fileNameToLastModTimeMutex.Lock() + defer h.fileNameToLastModTimeMutex.Unlock() + lastModTime := h.fileNameToLastModTime[fileName] + if !fileInfo.ModTime().After(lastModTime) { + return false + } + h.fileNameToLastModTime[fileName] = fileInfo.ModTime() + return true +} + +func (h *FSEventHandler) UpsertHash(fileName string, hash [sha256.Size]byte) (updated bool) { + h.hashesMutex.Lock() + defer h.hashesMutex.Unlock() + lastHash := h.hashes[fileName] + if lastHash == hash { + return false + } + h.hashes[fileName] = hash + return true +} + +// generate Go code for a single template. +// If a basePath is provided, the filename included in error messages is relative to it. +func (h *FSEventHandler) generate(ctx context.Context, fileName string) (goUpdated, textUpdated bool, diagnostics []parser.Diagnostic, err error) { + t, err := parser.Parse(fileName) + if err != nil { + return false, false, nil, fmt.Errorf("%s parsing error: %w", fileName, err) + } + targetFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.go" + + // Only use relative filenames to the basepath for filenames in runtime error messages. + relFilePath, err := filepath.Rel(h.dir, fileName) + if err != nil { + return false, false, nil, fmt.Errorf("failed to get relative path for %q: %w", fileName, err) + } + + var b bytes.Buffer + sourceMap, literals, err := generator.Generate(t, &b, append(h.genOpts, generator.WithFileName(relFilePath))...) + if err != nil { + return false, false, nil, fmt.Errorf("%s generation error: %w", fileName, err) + } + + formattedGoCode, err := format.Source(b.Bytes()) + if err != nil { + return false, false, nil, fmt.Errorf("%s source formatting error: %w", fileName, err) + } + + // Hash output, and write out the file if the goCodeHash has changed. + goCodeHash := sha256.Sum256(formattedGoCode) + if h.UpsertHash(targetFileName, goCodeHash) { + goUpdated = true + if err = os.WriteFile(targetFileName, formattedGoCode, 0o644); err != nil { + return false, false, nil, fmt.Errorf("failed to write target file %q: %w", targetFileName, err) + } + } + + // Add the txt file if it has changed. + if len(literals) > 0 { + txtFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.txt" + txtHash := sha256.Sum256([]byte(literals)) + if h.UpsertHash(txtFileName, txtHash) { + textUpdated = true + if err = os.WriteFile(txtFileName, []byte(literals), 0o644); err != nil { + return false, false, nil, fmt.Errorf("failed to write string literal file %q: %w", txtFileName, err) + } + } + } + + if h.genSourceMapVis { + err = generateSourceMapVisualisation(ctx, fileName, targetFileName, sourceMap) + } + + return goUpdated, textUpdated, t.Diagnostics, err +} + +func generateSourceMapVisualisation(ctx context.Context, templFileName, goFileName string, sourceMap *parser.SourceMap) error { + if err := ctx.Err(); err != nil { + return err + } + var templContents, goContents []byte + var templErr, goErr error + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + templContents, templErr = os.ReadFile(templFileName) + }() + go func() { + defer wg.Done() + goContents, goErr = os.ReadFile(goFileName) + }() + wg.Wait() + if templErr != nil { + return templErr + } + if goErr != nil { + return templErr + } + + targetFileName := strings.TrimSuffix(templFileName, ".templ") + "_templ_sourcemap.html" + w, err := os.Create(targetFileName) + if err != nil { + return fmt.Errorf("%s sourcemap visualisation error: %w", templFileName, err) + } + defer w.Close() + b := bufio.NewWriter(w) + defer b.Flush() + + return visualize.HTML(templFileName, string(templContents), string(goContents), sourceMap).Render(ctx, b) +} diff --git a/cmd/templ/generatecmd/fatalerror.go b/cmd/templ/generatecmd/fatalerror.go new file mode 100644 index 000000000..e1092df4c --- /dev/null +++ b/cmd/templ/generatecmd/fatalerror.go @@ -0,0 +1,23 @@ +package generatecmd + +type FatalError struct { + Err error +} + +func (e FatalError) Error() string { + return e.Err.Error() +} + +func (e FatalError) Unwrap() error { + return e.Err +} + +func (e FatalError) Is(target error) bool { + _, ok := target.(FatalError) + return ok +} + +func (e FatalError) As(target interface{}) bool { + _, ok := target.(*FatalError) + return ok +} diff --git a/cmd/templ/generatecmd/main.go b/cmd/templ/generatecmd/main.go index fc2652302..0669ccf5f 100644 --- a/cmd/templ/generatecmd/main.go +++ b/cmd/templ/generatecmd/main.go @@ -1,37 +1,14 @@ package generatecmd import ( - "bufio" - "bytes" "context" - "crypto/sha256" _ "embed" - "errors" - "fmt" - "go/format" "io" - "net/http" - "net/url" - "os" - "path" - "path/filepath" - "runtime" - "strings" - "sync" - "time" + "log/slog" _ "net/http/pprof" - "github.com/a-h/templ" - "github.com/a-h/templ/cmd/templ/generatecmd/modcheck" - "github.com/a-h/templ/cmd/templ/generatecmd/proxy" - "github.com/a-h/templ/cmd/templ/generatecmd/run" - "github.com/a-h/templ/cmd/templ/visualize" - "github.com/a-h/templ/generator" - "github.com/a-h/templ/parser/v2" - "github.com/cenkalti/backoff/v4" - "github.com/cli/browser" - "github.com/fatih/color" + "github.com/a-h/templ/cmd/templ/sloghandler" ) type Arguments struct { @@ -46,441 +23,28 @@ type Arguments struct { GenerateSourceMapVisualisations bool IncludeVersion bool IncludeTimestamp bool + Level string // PPROFPort is the port to run the pprof server on. PPROFPort int KeepOrphanedFiles bool } -var defaultWorkerCount = runtime.NumCPU() - -func Run(ctx context.Context, w io.Writer, args Arguments) (err error) { - if args.PPROFPort > 0 { - go func() { - _ = http.ListenAndServe(fmt.Sprintf("localhost:%d", args.PPROFPort), nil) - }() - } - - err = runCmd(ctx, w, args) - if errors.Is(err, context.Canceled) { - return nil - } - - return err -} - -func runCmd(ctx context.Context, w io.Writer, args Arguments) error { - var err error - - if args.Watch && args.FileName != "" { - return fmt.Errorf("cannot watch a single file, remove the -f or -watch flag") - } - var opts []generator.GenerateOpt - if args.IncludeVersion { - opts = append(opts, generator.WithVersion(templ.Version())) - } - if args.IncludeTimestamp { - opts = append(opts, generator.WithTimestamp(time.Now())) - } - if args.FileName != "" { - return processSingleFile(ctx, w, "", args.FileName, nil, args.GenerateSourceMapVisualisations, opts) - } - var target *url.URL - if args.Proxy != "" { - target, err = url.Parse(args.Proxy) - if err != nil { - return fmt.Errorf("failed to parse proxy URL: %w", err) - } - } - if args.ProxyPort == 0 { - args.ProxyPort = 7331 - } - - if args.WorkerCount == 0 { - args.WorkerCount = defaultWorkerCount - } - if !path.IsAbs(args.Path) { - args.Path, err = filepath.Abs(args.Path) - if err != nil { - return err - } - } - - var p *proxy.Handler - if args.Proxy != "" { - p = proxy.New(args.ProxyPort, target) - } - fmt.Fprintln(w, "Processing path:", args.Path) - - if err := modcheck.Check(args.Path); err != nil { - logWarning(w, "templ version check failed: %v\n", err) - } - - if args.Watch { - err = generateWatched(ctx, w, args, opts, p) - if err != nil && !errors.Is(err, context.Canceled) { - return err - } - } - - return generateProduction(context.Background(), w, args, opts, p) -} - -func generateWatched(ctx context.Context, w io.Writer, args Arguments, opts []generator.GenerateOpt, p *proxy.Handler) error { - fmt.Fprintln(w, "Generating dev code:", args.Path) - start := time.Now() - - bo := backoff.NewExponentialBackOff() - bo.InitialInterval = time.Millisecond * 500 - bo.MaxInterval = time.Second * 3 - bo.MaxElapsedTime = 0 - - var firstRunComplete bool - fileNameToLastModTime := make(map[string]time.Time) - fileNameToHash := make(map[string][sha256.Size]byte) - - for !firstRunComplete || args.Watch { - changesFound, errs := processChanges( - ctx, w, - fileNameToLastModTime, fileNameToHash, - args.Path, args.GenerateSourceMapVisualisations, - opts, args.WorkerCount, true, args.KeepOrphanedFiles) - if len(errs) > 0 { - if errors.Is(errs[0], context.Canceled) { - return errs[0] - } - if !args.Watch { - return fmt.Errorf("failed to process path: %v", errors.Join(errs...)) - } - logError(w, "Error processing path: %v\n", errors.Join(errs...)) - } - if changesFound > 0 { - if len(errs) > 0 { - logError(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start)) - } else { - logSuccess(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start)) - } - if args.Command != "" { - fmt.Fprintf(w, "Executing command: %s\n", args.Command) - if _, err := run.Run(ctx, args.Path, args.Command); err != nil { - fmt.Fprintf(w, "Error starting command: %v\n", err) - } - } - // Send server-sent event. - if p != nil { - p.SendSSE("message", "reload") - } - - if !firstRunComplete && p != nil { - go func() { - fmt.Fprintf(w, "Proxying from %s to target: %s\n", p.URL, p.Target.String()) - if err := http.ListenAndServe(fmt.Sprintf("127.0.0.1:%d", args.ProxyPort), p); err != nil { - fmt.Fprintf(w, "Error starting proxy: %v\n", err) - } - }() - if args.OpenBrowser { - go func() { - fmt.Fprintf(w, "Opening URL: %s\n", p.Target.String()) - if err := openURL(w, p.URL); err != nil { - fmt.Fprintf(w, "Error opening URL: %v\n", err) - } - }() - } - } - } - - if firstRunComplete { - if changesFound > 0 { - bo.Reset() - } - time.Sleep(bo.NextBackOff()) - } - - firstRunComplete = true - start = time.Now() - } - - return nil -} - -func generateProduction(ctx context.Context, w io.Writer, args Arguments, opts []generator.GenerateOpt, p *proxy.Handler) error { - fmt.Fprintln(w, "Generating production code:", args.Path) - start := time.Now() - - changesFound, errs := processChanges( - ctx, w, nil, nil, - args.Path, args.GenerateSourceMapVisualisations, - opts, args.WorkerCount, false, args.KeepOrphanedFiles) - if len(errs) > 0 { - if errors.Is(errs[0], context.Canceled) { - return errs[0] - } - logError(w, "Error processing path: %v\n", errors.Join(errs...)) - } - - if changesFound > 0 { - if len(errs) > 0 { - logError(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start)) - } else { - logSuccess(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start)) - } - if args.Command != "" { - fmt.Fprintf(w, "Executing command: %s\n", args.Command) - if _, err := run.Run(ctx, args.Path, args.Command); err != nil { - fmt.Fprintf(w, "Error starting command: %v\n", err) - } - } - } - - return nil -} - -func shouldSkipDir(dir string) bool { - if dir == "." { - return false - } - if dir == "vendor" || dir == "node_modules" { - return true - } - _, name := path.Split(dir) - // These directories are ignored by the Go tool. - if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") { - return true - } - return false -} - -func processChanges(ctx context.Context, stdout io.Writer, fileNameToLastModTime map[string]time.Time, hashes map[string][sha256.Size]byte, path string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt, maxWorkerCount int, watching, keepOrphanedFiles bool) (changesFound int, errs []error) { - sem := make(chan struct{}, maxWorkerCount) - var wg sync.WaitGroup - - if watching { - opts = append(opts, generator.WithExtractStrings()) - } - - if fileNameToLastModTime == nil { - fileNameToLastModTime = make(map[string]time.Time) - } - - err := filepath.WalkDir(path, func(fileName string, info os.DirEntry, err error) error { - if err != nil { - return err - } - if err = ctx.Err(); err != nil { - return err - } - if info.IsDir() && shouldSkipDir(fileName) { - return filepath.SkipDir - } - if info.IsDir() { - return nil - } - - orphaned := !keepOrphanedFiles && strings.HasSuffix(fileName, "_templ.go") - if orphaned { - // Make sure the generated file is orphaned - // by checking if the corresponding .templ file exists. - if _, err := os.Stat(strings.TrimSuffix(fileName, "_templ.go") + ".templ"); err == nil { - orphaned = false - } - } - - devTextFile := !watching && strings.HasSuffix(fileName, "_templ.txt") - if orphaned || devTextFile { - if err = os.Remove(fileName); err != nil { - return fmt.Errorf("failed to remove file: %w", err) - } - logWarning(stdout, "Deleted file %q\n", fileName) - return nil - } - - if strings.HasSuffix(fileName, ".templ") { - lastModTime := fileNameToLastModTime[fileName] - fileInfo, err := info.Info() - if err != nil { - return fmt.Errorf("failed to get file info: %w", err) - } - if fileInfo.ModTime().After(lastModTime) { - fileNameToLastModTime[fileName] = fileInfo.ModTime() - changesFound++ - - // Start a processor, but limit to maxWorkerCount. - sem <- struct{}{} - wg.Add(1) - go func() { - defer wg.Done() - if err := processSingleFile(ctx, stdout, path, fileName, hashes, generateSourceMapVisualisations, opts); err != nil { - errs = append(errs, err) - } - <-sem - }() - } - } - return nil - }) - if err != nil { - errs = append(errs, err) - } - - wg.Wait() - - return changesFound, errs -} - -func openURL(w io.Writer, url string) error { - backoff := backoff.NewExponentialBackOff() - backoff.InitialInterval = time.Second - var client http.Client - client.Timeout = 1 * time.Second - for { - if _, err := client.Get(url); err == nil { - break - } - d := backoff.NextBackOff() - fmt.Fprintf(w, "Server not ready. Retrying in %v...\n", d) - time.Sleep(d) - } - return browser.OpenURL(url) -} - -// processSingleFile generates Go code for a single template. -// If a basePath is provided, the filename included in error messages is relative to it. -func processSingleFile(ctx context.Context, stdout io.Writer, basePath, fileName string, hashes map[string][sha256.Size]byte, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) (err error) { - start := time.Now() - diag, err := generate(ctx, basePath, fileName, hashes, generateSourceMapVisualisations, opts) - if err != nil { - return err - } - var b bytes.Buffer - defer func() { - _, _ = b.WriteTo(stdout) - }() - if len(diag) > 0 { - logWarning(&b, "Generated code for %q in %s\n", fileName, time.Since(start)) - printDiagnostics(&b, fileName, diag) - return nil - } - logSuccess(&b, "Generated code for %q in %s\n", fileName, time.Since(start)) - return nil -} - -func printDiagnostics(w io.Writer, fileName string, diags []parser.Diagnostic) { - for _, d := range diags { - fmt.Fprint(w, "\t") - logWarning(w, "%s (%d:%d)\n", d.Message, d.Range.From.Line, d.Range.From.Col) - } - fmt.Fprintln(w) -} - -// generate Go code for a single template. -// If a basePath is provided, the filename included in error messages is relative to it. -func generate(ctx context.Context, basePath, fileName string, hashes map[string][sha256.Size]byte, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) (diagnostics []parser.Diagnostic, err error) { - if err = ctx.Err(); err != nil { - return - } - - if hashes == nil { - hashes = make(map[string][sha256.Size]byte) - } - - t, err := parser.Parse(fileName) - if err != nil { - return nil, fmt.Errorf("%s parsing error: %w", fileName, err) - } - targetFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.go" - - // Only use relative filenames to the basepath for filenames in runtime error messages. - errorMessageFileName := fileName - if basePath != "" { - errorMessageFileName, _ = filepath.Rel(basePath, fileName) - } - - var b bytes.Buffer - sourceMap, literals, err := generator.Generate(t, &b, append(opts, generator.WithFileName(errorMessageFileName))...) - if err != nil { - return nil, fmt.Errorf("%s generation error: %w", fileName, err) - } - - formattedGoCode, err := format.Source(b.Bytes()) - if err != nil { - return nil, fmt.Errorf("%s source formatting error: %w", fileName, err) - } - - // Hash output, and write out the file if the goCodeHash has changed. - goCodeHash := sha256.Sum256(formattedGoCode) - if hashes[targetFileName] != goCodeHash { - if err = os.WriteFile(targetFileName, formattedGoCode, 0o644); err != nil { - return nil, fmt.Errorf("failed to write target file %q: %w", targetFileName, err) - } - hashes[targetFileName] = goCodeHash - } - - // Add the txt file if it has changed. - if len(literals) > 0 { - txtFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.txt" - txtHash := sha256.Sum256([]byte(literals)) - if hashes[txtFileName] != txtHash { - if err = os.WriteFile(txtFileName, []byte(literals), 0o644); err != nil { - return nil, fmt.Errorf("failed to write string literal file %q: %w", txtFileName, err) - } - hashes[txtFileName] = txtHash - } - } - - if generateSourceMapVisualisations { - err = generateSourceMapVisualisation(ctx, fileName, targetFileName, sourceMap) - } - return t.Diagnostics, err -} - -func generateSourceMapVisualisation(ctx context.Context, templFileName, goFileName string, sourceMap *parser.SourceMap) error { - if err := ctx.Err(); err != nil { - return err - } - var templContents, goContents []byte - var templErr, goErr error - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - templContents, templErr = os.ReadFile(templFileName) - }() - go func() { - defer wg.Done() - goContents, goErr = os.ReadFile(goFileName) - }() - wg.Wait() - if templErr != nil { - return templErr - } - if goErr != nil { - return templErr - } - - targetFileName := strings.TrimSuffix(templFileName, ".templ") + "_templ_sourcemap.html" - w, err := os.Create(targetFileName) - if err != nil { - return fmt.Errorf("%s sourcemap visualisation error: %w", templFileName, err) - } - defer w.Close() - b := bufio.NewWriter(w) - defer b.Flush() - - return visualize.HTML(templFileName, string(templContents), string(goContents), sourceMap).Render(ctx, b) -} - -func logError(w io.Writer, format string, a ...any) { - logWithDecoration(w, "✗", color.FgRed, format, a...) -} - -func logWarning(w io.Writer, format string, a ...any) { - logWithDecoration(w, "!", color.FgYellow, format, a...) -} - -func logSuccess(w io.Writer, format string, a ...any) { - logWithDecoration(w, "✓", color.FgGreen, format, a...) -} - -func logWithDecoration(w io.Writer, decoration string, col color.Attribute, format string, a ...any) { - color.New(col).Fprintf(w, "(%s) ", decoration) - fmt.Fprintf(w, format, a...) +func Run(ctx context.Context, stderr io.Writer, args Arguments) (err error) { + level := slog.LevelInfo.Level() + switch args.Level { + case "debug": + level = slog.LevelDebug.Level() + case "warn": + level = slog.LevelWarn.Level() + case "error": + level = slog.LevelError.Level() + } + // The built-in attributes with keys "time", "level", "source", and "msg" + // are passed to this function, except that time is omitted + // if zero, and source is omitted if AddSource is false. + log := slog.New(sloghandler.NewHandler(stderr, &slog.HandlerOptions{ + AddSource: args.Level == "debug", + Level: level, + })) + return NewGenerate(log, args).Run(ctx) } diff --git a/cmd/templ/generatecmd/sse/server.go b/cmd/templ/generatecmd/sse/server.go index c847ed982..fb7fe923d 100644 --- a/cmd/templ/generatecmd/sse/server.go +++ b/cmd/templ/generatecmd/sse/server.go @@ -50,8 +50,10 @@ func (s *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Connection", "keep-alive") id := atomic.AddInt64(&s.counter, 1) + s.m.Lock() events := make(chan event) s.requests[id] = events + s.m.Unlock() defer func() { s.m.Lock() defer s.m.Unlock() @@ -70,7 +72,6 @@ loop: } timer.Reset(time.Second * 5) case e := <-events: - fmt.Println("Sending reload event...") if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", e.Type, e.Data); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/cmd/templ/generatecmd/watcher/watch.go b/cmd/templ/generatecmd/watcher/watch.go new file mode 100644 index 000000000..297da8d85 --- /dev/null +++ b/cmd/templ/generatecmd/watcher/watch.go @@ -0,0 +1,129 @@ +package watcher + +import ( + "context" + "os" + "path" + "path/filepath" + "strings" + + "github.com/fsnotify/fsnotify" +) + +func Recursive(ctx context.Context, path string, out chan fsnotify.Event, errors chan error) (w *RecursiveWatcher, err error) { + fsnw, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + w = &RecursiveWatcher{ + ctx: ctx, + w: fsnw, + Events: out, + Errors: errors, + } + go w.loop() + return w, w.Add(path) +} + +// WalkFiles walks the file tree rooted at path, sending a Create event for each +// file it encounters. +func WalkFiles(ctx context.Context, path string, out chan fsnotify.Event) (err error) { + return filepath.WalkDir(path, func(path string, info os.DirEntry, err error) error { + if err != nil { + return nil + } + if info.IsDir() && shouldSkipDir(path) { + return filepath.SkipDir + } + if !shouldIncludeFile(path) { + return nil + } + out <- fsnotify.Event{ + Name: path, + Op: fsnotify.Create, + } + return nil + }) +} + +func shouldIncludeFile(name string) bool { + if strings.HasSuffix(name, ".templ") { + return true + } + if strings.HasSuffix(name, "_templ.go") { + return true + } + if strings.HasSuffix(name, "_templ.txt") { + return true + } + return false +} + +type RecursiveWatcher struct { + ctx context.Context + w *fsnotify.Watcher + Events chan fsnotify.Event + Errors chan error +} + +func (w *RecursiveWatcher) Close() error { + return w.w.Close() +} + +func (w *RecursiveWatcher) loop() { + for { + select { + case <-w.ctx.Done(): + return + case event, ok := <-w.w.Events: + if !ok { + return + } + if event.Has(fsnotify.Create) { + if err := w.Add(event.Name); err != nil { + w.Errors <- err + } + } + // Only notify on templ related files. + if !shouldIncludeFile(event.Name) { + continue + } + w.Events <- event + case err, ok := <-w.w.Errors: + if !ok { + return + } + w.Errors <- err + } + } +} + +func (w *RecursiveWatcher) Add(dir string) error { + return filepath.WalkDir(dir, func(dir string, info os.DirEntry, err error) error { + if err != nil { + return nil + } + if !info.IsDir() { + return nil + } + if shouldSkipDir(dir) { + return filepath.SkipDir + } + return w.w.Add(dir) + }) +} + +func shouldSkipDir(dir string) bool { + if dir == "." { + return false + } + if dir == "vendor" || dir == "node_modules" { + return true + } + _, name := path.Split(dir) + // These directories are ignored by the Go tool. + if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") { + return true + } + return false +} diff --git a/cmd/templ/main.go b/cmd/templ/main.go index 43a7e5ee6..21db70933 100644 --- a/cmd/templ/main.go +++ b/cmd/templ/main.go @@ -92,6 +92,8 @@ Args: Port to run the pprof server on. -keep-orphaned-files Keeps orphaned generated templ files. (default false) + -level + Log verbosity level. (default "info") -help Print help and exit. @@ -126,6 +128,7 @@ func generateCmd(w io.Writer, args []string) (code int) { workerCountFlag := cmd.Int("w", runtime.NumCPU(), "") pprofPortFlag := cmd.Int("pprof", 0, "") keepOrphanedFilesFlag := cmd.Bool("keep-orphaned-files", false, "") + levelFlag := cmd.String("level", "info", "") helpFlag := cmd.Bool("help", false, "") err := cmd.Parse(args) if err != nil || *helpFlag { @@ -153,6 +156,7 @@ func generateCmd(w io.Writer, args []string) (code int) { GenerateSourceMapVisualisations: *sourceMapVisualisationsFlag, IncludeVersion: *includeVersionFlag, IncludeTimestamp: *includeTimestampFlag, + Level: *levelFlag, PPROFPort: *pprofPortFlag, KeepOrphanedFiles: *keepOrphanedFilesFlag, }) diff --git a/cmd/templ/sloghandler/handler.go b/cmd/templ/sloghandler/handler.go new file mode 100644 index 000000000..289405d84 --- /dev/null +++ b/cmd/templ/sloghandler/handler.go @@ -0,0 +1,101 @@ +package sloghandler + +import ( + "context" + "io" + "log/slog" + "strings" + "sync" + + "github.com/fatih/color" +) + +var _ slog.Handler = &Handler{} + +type Handler struct { + h slog.Handler + m *sync.Mutex + w io.Writer +} + +var levelToIcon = map[slog.Level]string{ + slog.LevelDebug: "(✓)", + slog.LevelInfo: "(✓)", + slog.LevelWarn: "(!)", + slog.LevelError: "(✗)", +} +var levelToColor = map[slog.Level]*color.Color{ + slog.LevelDebug: color.New(color.FgCyan), + slog.LevelInfo: color.New(color.FgGreen), + slog.LevelWarn: color.New(color.FgYellow), + slog.LevelError: color.New(color.FgRed), +} + +func NewHandler(w io.Writer, opts *slog.HandlerOptions) *Handler { + if opts == nil { + opts = &slog.HandlerOptions{} + } + return &Handler{ + w: w, + h: slog.NewTextHandler(w, &slog.HandlerOptions{ + Level: opts.Level, + AddSource: opts.AddSource, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if opts.ReplaceAttr != nil { + a = opts.ReplaceAttr(groups, a) + } + if a.Key == slog.LevelKey { + level, ok := levelToIcon[a.Value.Any().(slog.Level)] + if !ok { + level = a.Value.Any().(slog.Level).String() + } + a.Value = slog.StringValue(level) + return a + } + if a.Key == slog.TimeKey { + return slog.Attr{} + } + return a + }, + }), + m: &sync.Mutex{}, + } +} + +func (h *Handler) Enabled(ctx context.Context, level slog.Level) bool { + return h.h.Enabled(ctx, level) +} + +func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &Handler{h: h.h.WithAttrs(attrs), w: h.w, m: h.m} +} + +func (h *Handler) WithGroup(name string) slog.Handler { + return &Handler{h: h.h.WithGroup(name), w: h.w, m: h.m} +} + +var keyValueColor = color.New(color.Faint & color.FgBlack) + +func (h *Handler) Handle(ctx context.Context, r slog.Record) (err error) { + var sb strings.Builder + + sb.WriteString(levelToColor[r.Level].Sprint(levelToIcon[r.Level])) + sb.WriteString(" ") + sb.WriteString(r.Message) + + if r.NumAttrs() != 0 { + sb.WriteString(" [") + r.Attrs(func(a slog.Attr) bool { + sb.WriteString(keyValueColor.Sprintf(" %s=%s", a.Key, a.Value.String())) + return true + }) + sb.WriteString(" ]") + } + + sb.WriteString("\n") + + h.m.Lock() + defer h.m.Unlock() + _, err = io.WriteString(h.w, sb.String()) + return err +} diff --git a/examples/counter-basic/main.go b/examples/counter-basic/main.go index 4c9cb95bb..6c10367ee 100644 --- a/examples/counter-basic/main.go +++ b/examples/counter-basic/main.go @@ -63,7 +63,7 @@ func main() { // Start the server. fmt.Println("listening on :8080") - if err := http.ListenAndServe(":8080", muxWithSessionMiddleware); err != nil { + if err := http.ListenAndServe("127.0.0.1:8080", muxWithSessionMiddleware); err != nil { log.Printf("error listening: %v", err) } } diff --git a/flake.nix b/flake.nix index 69b11da0c..76070f7a8 100644 --- a/flake.nix +++ b/flake.nix @@ -34,7 +34,7 @@ name = "templ"; src = gitignore.lib.gitignoreSource ./.; subPackages = [ "cmd/templ" ]; - vendorHash = "sha256-Bk895ApJhpIHmQ5hApgdJRAJVZn3PUGJoNO1T7rIPz0="; + vendorHash = "sha256-U/KFUGi47dSE1YxKWOUlxvUR1BKI1snRjZlXZ8hY24c="; CGO_ENABLED = 0; flags = [ "-trimpath" diff --git a/go.mod b/go.mod index 96ebb82fa..83c69e7b6 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/cenkalti/backoff/v4 v4.2.1 github.com/cli/browser v1.2.0 github.com/fatih/color v1.16.0 + github.com/fsnotify/fsnotify v1.7.0 github.com/google/go-cmp v0.6.0 github.com/natefinch/atomic v1.0.1 github.com/rs/cors v1.8.3 diff --git a/go.sum b/go.sum index 1eaaf6470..17b54d767 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/cli/browser v1.2.0/go.mod h1:xFFnXLVcAyW9ni0cuo6NnrbCP75JxJ0RO7VtCBiH github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=