Skip to content

Commit

Permalink
cli: watch imported files
Browse files Browse the repository at this point in the history
  • Loading branch information
alixander committed Nov 10, 2023
1 parent 4c091f5 commit cd30bd5
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 43 deletions.
6 changes: 4 additions & 2 deletions d2cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"os"
"os/exec"
"os/user"
Expand Down Expand Up @@ -332,7 +333,7 @@ func Run(ctx context.Context, ms *xmain.State) (err error) {
ctx, cancel := timelib.WithTimeout(ctx, time.Minute*2)
defer cancel()

_, written, err := compile(ctx, ms, plugins, layoutFlag, renderOpts, fontFamily, *animateIntervalFlag, inputPath, outputPath, "", *bundleFlag, *forceAppendixFlag, pw.Page)
_, written, err := compile(ctx, ms, plugins, nil, layoutFlag, renderOpts, fontFamily, *animateIntervalFlag, inputPath, outputPath, "", *bundleFlag, *forceAppendixFlag, pw.Page)
if err != nil {
if written {
return fmt.Errorf("failed to fully compile (partial render written) %s: %w", ms.HumanPath(inputPath), err)
Expand Down Expand Up @@ -367,7 +368,7 @@ func LayoutResolver(ctx context.Context, ms *xmain.State, plugins []d2plugin.Plu
}
}

func compile(ctx context.Context, ms *xmain.State, plugins []d2plugin.Plugin, layout *string, renderOpts d2svg.RenderOpts, fontFamily *d2fonts.FontFamily, animateInterval int64, inputPath, outputPath, boardPath string, bundle, forceAppendix bool, page playwright.Page) (_ []byte, written bool, _ error) {
func compile(ctx context.Context, ms *xmain.State, plugins []d2plugin.Plugin, fs fs.FS, layout *string, renderOpts d2svg.RenderOpts, fontFamily *d2fonts.FontFamily, animateInterval int64, inputPath, outputPath, boardPath string, bundle, forceAppendix bool, page playwright.Page) (_ []byte, written bool, _ error) {
start := time.Now()
input, err := ms.ReadPath(inputPath)
if err != nil {
Expand All @@ -385,6 +386,7 @@ func compile(ctx context.Context, ms *xmain.State, plugins []d2plugin.Plugin, la
InputPath: inputPath,
LayoutResolver: LayoutResolver(ctx, ms, plugins),
Layout: layout,
FS: fs,
}

cancel := background.Repeat(func() {
Expand Down
112 changes: 94 additions & 18 deletions d2cli/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"os"
"path/filepath"
"runtime"
"sort"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -218,10 +219,13 @@ func (w *watcher) goFunc(fn func(context.Context) error) {
* TODO: Abstract out file system and fsnotify to test this with 100% coverage. See comment in main_test.go
*/
func (w *watcher) watchLoop(ctx context.Context) error {
lastModified, err := w.ensureAddWatch(ctx)
lastModified := make(map[string]time.Time)

mt, err := w.ensureAddWatch(ctx, w.inputPath)
if err != nil {
return err
}
lastModified[w.inputPath] = mt
w.ms.Log.Info.Printf("compiling %v...", w.ms.HumanPath(w.inputPath))
w.requestCompile()

Expand All @@ -230,40 +234,49 @@ func (w *watcher) watchLoop(ctx context.Context) error {
pollTicker := time.NewTicker(time.Second * 10)
defer pollTicker.Stop()

changed := make(map[string]struct{})

for {
select {
case <-pollTicker.C:
// In case we missed an event indicating the path is unwatchable and we won't be
// getting any more events.
// File notification APIs are notoriously unreliable. I've personally experienced
// many quirks and so feel this check is justified even if excessive.
mt, err := w.ensureAddWatch(ctx)
if err != nil {
return err
missedChanges := false
for _, watched := range w.fw.WatchList() {
mt, err := w.ensureAddWatch(ctx, watched)
if err != nil {
return err
}
if mt2, ok := lastModified[watched]; !ok || !mt.Equal(mt2) {
missedChanges = true
lastModified[watched] = mt
w.requestCompile()
}
}
if !mt.Equal(lastModified) {
// We missed changes.
lastModified = mt
if missedChanges {
w.requestCompile()
}
case ev, ok := <-w.fw.Events:
if !ok {
return errors.New("fsnotify watcher closed")
}
w.ms.Log.Debug.Printf("received file system event %v", ev)
mt, err := w.ensureAddWatch(ctx)
mt, err := w.ensureAddWatch(ctx, ev.Name)
if err != nil {
return err
}
if ev.Op == fsnotify.Chmod {
if mt.Equal(lastModified) {
if mt.Equal(lastModified[ev.Name]) {
// Benign Chmod.
// See https://github.com/fsnotify/fsnotify/issues/15
continue
}
// We missed changes.
lastModified = mt
lastModified[ev.Name] = mt
}
changed[ev.Name] = struct{}{}
// The purpose of eatBurstTimer is to wait at least 16 milliseconds after a sequence of
// events to ensure that whomever is editing the file is now done.
//
Expand All @@ -276,8 +289,18 @@ func (w *watcher) watchLoop(ctx context.Context) error {
// misleading error.
eatBurstTimer.Reset(time.Millisecond * 16)
case <-eatBurstTimer.C:
w.ms.Log.Info.Printf("detected change in %v: recompiling...", w.ms.HumanPath(w.inputPath))
var changedList []string
for k := range changed {
changedList = append(changedList, k)
}
sort.Strings(changedList)
changedStr := w.ms.HumanPath(changedList[0])
for i := 1; i < len(changed); i++ {
changedStr += fmt.Sprintf(", %s", w.ms.HumanPath(changedList[i]))
}
w.ms.Log.Info.Printf("detected change in %s: recompiling...", changedStr)
w.requestCompile()
changed = make(map[string]struct{})
case err, ok := <-w.fw.Errors:
if !ok {
return errors.New("fsnotify watcher closed")
Expand All @@ -296,17 +319,17 @@ func (w *watcher) requestCompile() {
}
}

func (w *watcher) ensureAddWatch(ctx context.Context) (time.Time, error) {
func (w *watcher) ensureAddWatch(ctx context.Context, path string) (time.Time, error) {
interval := time.Millisecond * 16
tc := time.NewTimer(0)
<-tc.C
for {
mt, err := w.addWatch(ctx)
mt, err := w.addWatch(ctx, path)
if err == nil {
return mt, nil
}
if interval >= time.Second {
w.ms.Log.Error.Printf("failed to watch inputPath %q: %v (retrying in %v)", w.ms.HumanPath(w.inputPath), err, interval)
w.ms.Log.Error.Printf("failed to watch %q: %v (retrying in %v)", w.ms.HumanPath(path), err, interval)
}

tc.Reset(interval)
Expand All @@ -324,19 +347,56 @@ func (w *watcher) ensureAddWatch(ctx context.Context) (time.Time, error) {
}
}

func (w *watcher) addWatch(ctx context.Context) (time.Time, error) {
err := w.fw.Add(w.inputPath)
func (w *watcher) addWatch(ctx context.Context, path string) (time.Time, error) {
err := w.fw.Add(path)
if err != nil {
return time.Time{}, err
}
var d os.FileInfo
d, err = os.Stat(w.inputPath)
d, err = os.Stat(path)
if err != nil {
return time.Time{}, err
}
return d.ModTime(), nil
}

func (w *watcher) replaceWatchList(ctx context.Context, paths []string) error {
// First remove the files no longer being watched
for _, watched := range w.fw.WatchList() {
if watched == w.inputPath {
continue
}
found := false
for _, p := range paths {
if watched == p {
found = true
break
}
}
if !found {
// Don't mind errors here
w.fw.Remove(watched)
}
}
// Then add the files newly being watched
for _, p := range paths {
found := false
for _, watched := range w.fw.WatchList() {
if watched == p {
found = true
break
}
}
if !found {
_, err := w.ensureAddWatch(ctx, p)
if err != nil {
return err
}
}
}
return nil
}

func (w *watcher) compileLoop(ctx context.Context) error {
firstCompile := true
for {
Expand Down Expand Up @@ -364,7 +424,8 @@ func (w *watcher) compileLoop(ctx context.Context) error {
w.pw = newPW
}

svg, _, err := compile(ctx, w.ms, w.plugins, w.layout, w.renderOpts, w.fontFamily, w.animateInterval, w.inputPath, w.outputPath, w.boardPath, w.bundle, w.forceAppendix, w.pw.Page)
fs := trackedFS{}
svg, _, err := compile(ctx, w.ms, w.plugins, &fs, w.layout, w.renderOpts, w.fontFamily, w.animateInterval, w.inputPath, w.outputPath, w.boardPath, w.bundle, w.forceAppendix, w.pw.Page)
errs := ""
if err != nil {
if len(svg) > 0 {
Expand All @@ -375,6 +436,11 @@ func (w *watcher) compileLoop(ctx context.Context) error {
errs = err.Error()
w.ms.Log.Error.Print(errs)
}
err = w.replaceWatchList(ctx, fs.opened)
if err != nil {
return err
}

w.broadcast(&compileResult{
SVG: string(svg),
Scale: w.renderOpts.Scale,
Expand Down Expand Up @@ -574,3 +640,13 @@ func wsHeartbeat(ctx context.Context, c *websocket.Conn) {
}
}
}

// trackedFS is OS's FS with the addition that it tracks which files are opened
type trackedFS struct {
opened []string
}

func (tfs *trackedFS) Open(name string) (fs.File, error) {
tfs.opened = append(tfs.opened, name)
return os.Open(name)
}
Loading

0 comments on commit cd30bd5

Please sign in to comment.