diff --git a/cmd/templ/generatecmd/watcher/watch.go b/cmd/templ/generatecmd/watcher/watch.go index 297da8d85..57d725f2a 100644 --- a/cmd/templ/generatecmd/watcher/watch.go +++ b/cmd/templ/generatecmd/watcher/watch.go @@ -6,11 +6,18 @@ import ( "path" "path/filepath" "strings" + "sync" + "time" "github.com/fsnotify/fsnotify" ) -func Recursive(ctx context.Context, path string, out chan fsnotify.Event, errors chan error) (w *RecursiveWatcher, err error) { +func Recursive( + ctx context.Context, + path string, + out chan fsnotify.Event, + errors chan error, +) (w *RecursiveWatcher, err error) { fsnw, err := fsnotify.NewWatcher() if err != nil { return nil, err @@ -20,6 +27,7 @@ func Recursive(ctx context.Context, path string, out chan fsnotify.Event, errors w: fsnw, Events: out, Errors: errors, + timers: make(map[timerKey]*time.Timer), } go w.loop() return w, w.Add(path) @@ -60,10 +68,24 @@ func shouldIncludeFile(name string) bool { } type RecursiveWatcher struct { - ctx context.Context - w *fsnotify.Watcher - Events chan fsnotify.Event - Errors chan error + ctx context.Context + w *fsnotify.Watcher + Events chan fsnotify.Event + Errors chan error + timerMu sync.Mutex + timers map[timerKey]*time.Timer +} + +type timerKey struct { + name string + op fsnotify.Op +} + +func timerKeyFromEvent(event fsnotify.Event) timerKey { + return timerKey{ + name: event.Name, + op: event.Op, + } } func (w *RecursiveWatcher) Close() error { @@ -88,7 +110,20 @@ func (w *RecursiveWatcher) loop() { if !shouldIncludeFile(event.Name) { continue } - w.Events <- event + tk := timerKeyFromEvent(event) + w.timerMu.Lock() + t, ok := w.timers[tk] + w.timerMu.Unlock() + if !ok { + t = time.AfterFunc(100*time.Millisecond, func() { + w.Events <- event + }) + w.timerMu.Lock() + w.timers[tk] = t + w.timerMu.Unlock() + continue + } + t.Reset(100 * time.Millisecond) case err, ok := <-w.w.Errors: if !ok { return diff --git a/cmd/templ/generatecmd/watcher/watch_test.go b/cmd/templ/generatecmd/watcher/watch_test.go new file mode 100644 index 000000000..e90180da7 --- /dev/null +++ b/cmd/templ/generatecmd/watcher/watch_test.go @@ -0,0 +1,125 @@ +package watcher + +import ( + "context" + "testing" + "time" + + "github.com/fsnotify/fsnotify" +) + +func TestWatchDebouncesDuplicates(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + rw := &RecursiveWatcher{ + ctx: ctx, + w: &fsnotify.Watcher{ + Events: make(chan fsnotify.Event), + }, + Events: make(chan fsnotify.Event, 2), + timers: make(map[timerKey]*time.Timer), + } + go func() { + rw.w.Events <- fsnotify.Event{Name: "test.templ"} + rw.w.Events <- fsnotify.Event{Name: "test.templ"} + cancel() + close(rw.w.Events) + }() + rw.loop() + count := 0 + exp := time.After(300 * time.Millisecond) + for { + select { + case <-rw.Events: + count++ + case <-exp: + if count != 1 { + t.Errorf("expected 1 event, got %d", count) + } + return + } + } +} + +func TestWatchDoesNotDebounceDifferentEvents(t *testing.T) { + tests := []struct { + event1 fsnotify.Event + event2 fsnotify.Event + }{ + // Different files + {fsnotify.Event{Name: "test.templ"}, fsnotify.Event{Name: "test2.templ"}}, + // Different operations + { + fsnotify.Event{Name: "test.templ", Op: fsnotify.Create}, + fsnotify.Event{Name: "test.templ", Op: fsnotify.Write}, + }, + // Different operations and files + { + fsnotify.Event{Name: "test.templ", Op: fsnotify.Create}, + fsnotify.Event{Name: "test2.templ", Op: fsnotify.Write}, + }, + } + for _, test := range tests { + ctx, cancel := context.WithCancel(context.Background()) + rw := &RecursiveWatcher{ + ctx: ctx, + w: &fsnotify.Watcher{ + Events: make(chan fsnotify.Event), + }, + Events: make(chan fsnotify.Event, 2), + timers: make(map[timerKey]*time.Timer), + } + go func() { + rw.w.Events <- test.event1 + rw.w.Events <- test.event2 + cancel() + close(rw.w.Events) + }() + rw.loop() + count := 0 + exp := time.After(300 * time.Millisecond) + for { + select { + case <-rw.Events: + count++ + case <-exp: + if count != 2 { + t.Errorf("expected 2 event, got %d", count) + } + return + } + } + } +} + +func TestWatchDoesNotDebounceSeparateEvents(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + rw := &RecursiveWatcher{ + ctx: ctx, + w: &fsnotify.Watcher{ + Events: make(chan fsnotify.Event), + }, + Events: make(chan fsnotify.Event, 2), + timers: make(map[timerKey]*time.Timer), + } + go func() { + rw.w.Events <- fsnotify.Event{Name: "test.templ"} + <-time.After(200 * time.Millisecond) + rw.w.Events <- fsnotify.Event{Name: "test.templ"} + cancel() + close(rw.w.Events) + }() + rw.loop() + count := 0 + exp := time.After(500 * time.Millisecond) + for { + select { + case <-rw.Events: + count++ + case <-exp: + if count != 2 { + t.Errorf("expected 2 event, got %d", count) + } + return + } + } +}