diff --git a/cmd/litestream/main.go b/cmd/litestream/main.go index 3373eca4..fa43f4dc 100644 --- a/cmd/litestream/main.go +++ b/cmd/litestream/main.go @@ -267,7 +267,6 @@ func ReadConfigFile(filename string, expandEnv bool) (_ Config, err error) { // DBConfig represents the configuration for a single database. type DBConfig struct { Path string `yaml:"path"` - MonitorInterval *time.Duration `yaml:"monitor-interval"` CheckpointInterval *time.Duration `yaml:"checkpoint-interval"` MinCheckpointPageN *int `yaml:"min-checkpoint-page-count"` MaxCheckpointPageN *int `yaml:"max-checkpoint-page-count"` @@ -281,14 +280,15 @@ func NewDBFromConfig(dbc *DBConfig) (*litestream.DB, error) { if err != nil { return nil, err } + return NewDBFromConfigWithPath(dbc, path) +} +// NewDBFromConfigWithPath instantiates a DB based on a configuration and using a given path. +func NewDBFromConfigWithPath(dbc *DBConfig, path string) (*litestream.DB, error) { // Initialize database with given path. db := litestream.NewDB(path) // Override default database settings if specified in configuration. - if dbc.MonitorInterval != nil { - db.MonitorInterval = *dbc.MonitorInterval - } if dbc.CheckpointInterval != nil { db.CheckpointInterval = *dbc.CheckpointInterval } diff --git a/cmd/litestream/replicate.go b/cmd/litestream/replicate.go index fa849073..e0ae7bd1 100644 --- a/cmd/litestream/replicate.go +++ b/cmd/litestream/replicate.go @@ -35,8 +35,7 @@ type ReplicateCommand struct { Config Config - // List of managed databases specified in the config. - DBs []*litestream.DB + server *litestream.Server } // NewReplicateCommand returns a new instance of ReplicateCommand. @@ -104,21 +103,27 @@ func (c *ReplicateCommand) Run(ctx context.Context) (err error) { log.Println("no databases specified in configuration") } + c.server = litestream.NewServer() + if err := c.server.Open(); err != nil { + return fmt.Errorf("open server: %w", err) + } + + // Add databases to the server. for _, dbConfig := range c.Config.DBs { - db, err := NewDBFromConfig(dbConfig) + path, err := expand(dbConfig.Path) if err != nil { return err } - // Open database & attach to program. - if err := db.Open(); err != nil { + if err := c.server.Watch(path, func(path string) (*litestream.DB, error) { + return NewDBFromConfigWithPath(dbConfig, path) + }); err != nil { return err } - c.DBs = append(c.DBs, db) } // Notify user that initialization is done. - for _, db := range c.DBs { + for _, db := range c.server.DBs() { log.Printf("initialized db: %s", db.Path()) for _, r := range db.Replicas { switch client := r.Client().(type) { @@ -180,13 +185,8 @@ func (c *ReplicateCommand) Run(ctx context.Context) (err error) { // Close closes all open databases. func (c *ReplicateCommand) Close() (err error) { - for _, db := range c.DBs { - if e := db.Close(); e != nil { - log.Printf("error closing db: path=%s err=%s", db.Path(), e) - if err == nil { - err = e - } - } + if e := c.server.Close(); e != nil && err == nil { + err = e } return err } diff --git a/db.go b/db.go index f73cf0ee..84acd616 100644 --- a/db.go +++ b/db.go @@ -28,12 +28,15 @@ import ( // Default DB settings. const ( - DefaultMonitorInterval = 1 * time.Second DefaultCheckpointInterval = 1 * time.Minute DefaultMinCheckpointPageN = 1000 DefaultMaxCheckpointPageN = 10000 ) +// MonitorDelayInterval is the time Litestream will wait after receiving a file +// change notification before processing the WAL file for changes. +const MonitorDelayInterval = 100 * time.Millisecond + // MaxIndex is the maximum possible WAL index. // If this index is reached then a new generation will be started. const MaxIndex = 0x7FFFFFFF @@ -43,14 +46,15 @@ const BusyTimeout = 1 * time.Second // DB represents a managed instance of a SQLite database in the file system. type DB struct { - mu sync.RWMutex - path string // part to database - db *sql.DB // target database - f *os.File // long-running db file descriptor - rtx *sql.Tx // long running read transaction - pos Pos // cached position - pageSize int // page size, in bytes - notify chan struct{} // closes on WAL change + mu sync.RWMutex + path string // part to database + db *sql.DB // target database + f *os.File // long-running db file descriptor + rtx *sql.Tx // long running read transaction + pos Pos // cached position + pageSize int // page size, in bytes + notifyCh chan struct{} // notifies DB of changes + walNotify chan struct{} // closes on WAL change // Cached salt & checksum from current shadow header. hdr []byte @@ -98,9 +102,6 @@ type DB struct { // better precision. CheckpointInterval time.Duration - // Frequency at which to perform db sync. - MonitorInterval time.Duration - // List of replicas for the database. // Must be set before calling Open(). Replicas []*Replica @@ -111,13 +112,13 @@ type DB struct { // NewDB returns a new instance of DB for a given path. func NewDB(path string) *DB { db := &DB{ - path: path, - notify: make(chan struct{}), + path: path, + notifyCh: make(chan struct{}, 1), + walNotify: make(chan struct{}), MinCheckpointPageN: DefaultMinCheckpointPageN, MaxCheckpointPageN: DefaultMaxCheckpointPageN, CheckpointInterval: DefaultCheckpointInterval, - MonitorInterval: DefaultMonitorInterval, Logger: log.New(LogWriter, fmt.Sprintf("%s: ", logPrefixPath(path)), LogFlags), } @@ -358,11 +359,16 @@ func (db *DB) walSegmentOffsetsByIndex(generation string, index int) ([]int64, e return offsets, nil } -// Notify returns a channel that closes when the shadow WAL changes. -func (db *DB) Notify() <-chan struct{} { +// NotifyCh returns a channel that can be used to signal changes in the DB. +func (db *DB) NotifyCh() chan<- struct{} { + return db.notifyCh +} + +// WALNotify returns a channel that closes when the shadow WAL changes. +func (db *DB) WALNotify() <-chan struct{} { db.mu.RLock() defer db.mu.RUnlock() - return db.notify + return db.walNotify } // PageSize returns the page size of the underlying database. @@ -395,10 +401,8 @@ func (db *DB) Open() (err error) { } // Start monitoring SQLite database in a separate goroutine. - if db.MonitorInterval > 0 { - db.wg.Add(1) - go func() { defer db.wg.Done(); db.monitor() }() - } + db.wg.Add(1) + go func() { defer db.wg.Done(); db.monitor() }() return nil } @@ -903,8 +907,8 @@ func (db *DB) Sync(ctx context.Context) (err error) { // Notify replicas of WAL changes. if db.pos != origPos { - close(db.notify) - db.notify = make(chan struct{}) + close(db.walNotify) + db.walNotify = make(chan struct{}) } return nil @@ -1367,18 +1371,27 @@ func (db *DB) execCheckpoint(mode string) (err error) { // monitor runs in a separate goroutine and monitors the database & WAL. func (db *DB) monitor() { - ticker := time.NewTicker(db.MonitorInterval) - defer ticker.Stop() + timer := time.NewTimer(MonitorDelayInterval) + defer timer.Stop() for { - // Wait for ticker or context close. + // Wait for a file change notification from the file system. select { case <-db.ctx.Done(): return - case <-ticker.C: + case <-db.notifyCh: + } + + // Wait for small delay before processing changes. + timer.Reset(MonitorDelayInterval) + <-timer.C + + // Clear any additional change notifications that occurred during delay. + select { + case <-db.notifyCh: + default: } - // Sync the database to the shadow WAL. if err := db.Sync(db.ctx); err != nil && !errors.Is(err, context.Canceled) { db.Logger.Printf("sync error: %s", err) } diff --git a/db_test.go b/db_test.go index a9dbb585..f424fc44 100644 --- a/db_test.go +++ b/db_test.go @@ -560,7 +560,6 @@ func MustOpenDB(tb testing.TB) *litestream.DB { func MustOpenDBAt(tb testing.TB, path string) *litestream.DB { tb.Helper() db := litestream.NewDB(path) - db.MonitorInterval = 0 // disable background goroutine if err := db.Open(); err != nil { tb.Fatal(err) } diff --git a/go.mod b/go.mod index 13575b62..a01cc98f 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/prometheus/client_golang v1.12.1 golang.org/x/crypto v0.0.0-20220112180741-5e0467b6c7ce golang.org/x/sync v0.0.0-20210220032951-036812b2e83c + golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a // indirect google.golang.org/api v0.66.0 gopkg.in/yaml.v2 v2.4.0 ) diff --git a/go.sum b/go.sum index f7218f02..cfbb6548 100644 --- a/go.sum +++ b/go.sum @@ -485,6 +485,8 @@ golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 h1:XfKQ4OlFl8okEOr5UvAqFRVj8pY/4yfcXrddB8qAbU0= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a h1:ppl5mZgokTT8uPkmYOyEUmPTr3ypaKkg5eFOGrAmxxE= +golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/integration/cmd_test.go b/integration/cmd_test.go index 92d143b4..a663f9c2 100644 --- a/integration/cmd_test.go +++ b/integration/cmd_test.go @@ -43,6 +43,8 @@ func TestCmd_Replicate_OK(t *testing.T) { db, err := sql.Open("sqlite3", filepath.Join(tempDir, "db")) if err != nil { t.Fatal(err) + } else if _, err := db.ExecContext(ctx, `PRAGMA journal_mode = wal`); err != nil { + t.Fatal(err) } else if _, err := db.ExecContext(ctx, `CREATE TABLE t (id INTEGER PRIMARY KEY)`); err != nil { t.Fatal(err) } @@ -378,9 +380,9 @@ func waitForLogMessage(tb testing.TB, b *internal.LockingBuffer, msg string) { // killLitestreamCmd interrupts the process and waits for a clean shutdown. func killLitestreamCmd(tb testing.TB, cmd *exec.Cmd, stdout *internal.LockingBuffer) { if err := cmd.Process.Signal(os.Interrupt); err != nil { - tb.Fatal(err) + tb.Fatal("kill litestream: signal:", err) } else if err := cmd.Wait(); err != nil { - tb.Fatal(err) + tb.Fatal("kill litestream: cmd:", err) } } diff --git a/internal/file_watcher.go b/internal/file_watcher.go new file mode 100644 index 00000000..501703d4 --- /dev/null +++ b/internal/file_watcher.go @@ -0,0 +1,36 @@ +package internal + +import ( + "errors" +) + +// File event mask constants. +const ( + FileEventCreated = 1 << iota + FileEventModified + FileEventDeleted +) + +// FileEvent represents an event on a watched file. +type FileEvent struct { + Name string + Mask int +} + +// ErrFileEventQueueOverflow is returned when the file event queue has overflowed. +var ErrFileEventQueueOverflow = errors.New("file event queue overflow") + +// FileWatcher represents a watcher of file events. +type FileWatcher interface { + Open() error + Close() error + + // Returns a channel of events for watched files. + Events() <-chan FileEvent + + // Adds a specific file to be watched. + Watch(filename string) error + + // Removes a specific file from being watched. + Unwatch(filename string) error +} diff --git a/internal/file_watcher_bsd.go b/internal/file_watcher_bsd.go new file mode 100644 index 00000000..26737fd1 --- /dev/null +++ b/internal/file_watcher_bsd.go @@ -0,0 +1,258 @@ +//go:build freebsd || openbsd || netbsd || dragonfly || darwin + +package internal + +import ( + "context" + "log" + "os" + "path/filepath" + "sync" + "time" + + "golang.org/x/sync/errgroup" + "golang.org/x/sys/unix" +) + +var _ FileWatcher = (*KqueueFileWatcher)(nil) + +// KqueueFileWatcher watches files and is notified of events on them. +// +// Watcher code based on https://github.com/fsnotify/fsnotify +type KqueueFileWatcher struct { + fd int + events chan FileEvent + + mu sync.Mutex + watches map[string]int + paths map[int]string + notExists map[string]struct{} + + g errgroup.Group + ctx context.Context + cancel func() +} + +// NewKqueueFileWatcher returns a new instance of KqueueFileWatcher. +func NewKqueueFileWatcher() *KqueueFileWatcher { + return &KqueueFileWatcher{ + events: make(chan FileEvent), + + watches: make(map[string]int), + paths: make(map[int]string), + notExists: make(map[string]struct{}), + } +} + +// NewFileWatcher returns an instance of KqueueFileWatcher on BSD systems. +func NewFileWatcher() FileWatcher { + return NewKqueueFileWatcher() +} + +// Events returns a read-only channel of file events. +func (w *KqueueFileWatcher) Events() <-chan FileEvent { + return w.events +} + +// Open initializes the watcher and begins listening for file events. +func (w *KqueueFileWatcher) Open() (err error) { + if w.fd, err = unix.Kqueue(); err != nil { + return err + } + + w.ctx, w.cancel = context.WithCancel(context.Background()) + w.g.Go(func() error { + if err := w.monitor(w.ctx); err != nil && w.ctx.Err() == nil { + return err + } + return nil + }) + w.g.Go(func() error { + if err := w.monitorNotExists(w.ctx); err != nil && w.ctx.Err() == nil { + return err + } + return nil + }) + + return nil +} + +// Close stops watching for file events and cleans up resources. +func (w *KqueueFileWatcher) Close() (err error) { + w.cancel() + + if w.fd != 0 { + if e := unix.Close(w.fd); e != nil && err == nil { + err = e + } + } + + if e := w.g.Wait(); e != nil && err == nil { + err = e + } + return err +} + +// Watch begins watching the given file or directory. +func (w *KqueueFileWatcher) Watch(filename string) error { + w.mu.Lock() + defer w.mu.Unlock() + + filename = filepath.Clean(filename) + + // If file doesn't exist, monitor separately until it does exist as we + // can't watch non-existent files with kqueue. + if _, err := os.Stat(filename); os.IsNotExist(err) { + w.notExists[filename] = struct{}{} + return nil + } + + return w.addWatch(filename) +} + +func (w *KqueueFileWatcher) addWatch(filename string) error { + wd, err := unix.Open(filename, unix.O_NONBLOCK|unix.O_RDONLY|unix.O_CLOEXEC, 0700) + if err != nil { + return err + } + + // TODO: Handle return count different than 1. + kevent := unix.Kevent_t{Fflags: unix.NOTE_DELETE | unix.NOTE_WRITE} + unix.SetKevent(&kevent, wd, unix.EVFILT_VNODE, unix.EV_ADD|unix.EV_CLEAR|unix.EV_ENABLE) + if _, err := unix.Kevent(w.fd, []unix.Kevent_t{kevent}, nil, nil); err != nil { + return err + } + + w.watches[filename] = wd + w.paths[wd] = filename + + delete(w.notExists, filename) + + return err +} + +// Unwatch stops watching the given file or directory. +func (w *KqueueFileWatcher) Unwatch(filename string) error { + w.mu.Lock() + defer w.mu.Unlock() + + filename = filepath.Clean(filename) + + // Look up watch ID by filename. + wd, ok := w.watches[filename] + if !ok { + return nil + } + + // TODO: Handle return count different than 1. + var kevent unix.Kevent_t + unix.SetKevent(&kevent, wd, unix.EVFILT_VNODE, unix.EV_DELETE) + if _, err := unix.Kevent(w.fd, []unix.Kevent_t{kevent}, nil, nil); err != nil { + return err + } + unix.Close(wd) + + delete(w.paths, wd) + delete(w.watches, filename) + delete(w.notExists, filename) + + return nil +} + +// monitorNotExist runs in a separate goroutine and monitors for the creation of +// watched files that do not yet exist. +func (w *KqueueFileWatcher) monitorNotExists(ctx context.Context) error { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + w.checkNotExists(ctx) + } + } +} + +func (w *KqueueFileWatcher) checkNotExists(ctx context.Context) { + w.mu.Lock() + defer w.mu.Unlock() + + for filename := range w.notExists { + if _, err := os.Stat(filename); os.IsNotExist(err) { + continue + } + + if err := w.addWatch(filename); err != nil { + log.Printf("non-existent file monitor: cannot add watch: %s", err) + } + + // Send event to channel. + select { + case w.events <- FileEvent{ + Name: filename, + Mask: FileEventCreated, + }: + default: + } + } +} + +// monitor runs in a separate goroutine and monitors the inotify event queue. +func (w *KqueueFileWatcher) monitor(ctx context.Context) error { + kevents := make([]unix.Kevent_t, 10) + timeout := unix.NsecToTimespec(int64(100 * time.Millisecond)) + + for { + n, err := unix.Kevent(w.fd, nil, kevents, &timeout) + if err != nil && err != unix.EINTR { + return err + } else if n < 0 { + continue + } + + for _, kevent := range kevents[:n] { + if err := w.recv(ctx, &kevent); err != nil { + return err + } + } + } +} + +// recv processes a single event from kqeueue. +func (w *KqueueFileWatcher) recv(ctx context.Context, kevent *unix.Kevent_t) error { + if err := ctx.Err(); err != nil { + return err + } + + // Look up filename & remove from watcher if this is a delete. + w.mu.Lock() + filename, ok := w.paths[int(kevent.Ident)] + if ok && kevent.Fflags&unix.NOTE_DELETE != 0 { + delete(w.paths, int(kevent.Ident)) + delete(w.watches, filename) + unix.Close(int(kevent.Ident)) + } + w.mu.Unlock() + + // Convert to generic file event mask. + var mask int + if kevent.Fflags&unix.NOTE_WRITE != 0 { + mask |= FileEventModified + } + if kevent.Fflags&unix.NOTE_DELETE != 0 { + mask |= FileEventDeleted + } + + // Send event to channel or wait for close. + select { + case <-ctx.Done(): + return ctx.Err() + case w.events <- FileEvent{ + Name: filename, + Mask: mask, + }: + return nil + } +} diff --git a/internal/file_watcher_linux.go b/internal/file_watcher_linux.go new file mode 100644 index 00000000..a337fd03 --- /dev/null +++ b/internal/file_watcher_linux.go @@ -0,0 +1,365 @@ +//go:build linux + +package internal + +import ( + "context" + "fmt" + "log" + "os" + "path/filepath" + "sync" + "time" + "unsafe" + + "golang.org/x/sync/errgroup" + "golang.org/x/sys/unix" +) + +var _ FileWatcher = (*InotifyFileWatcher)(nil) + +// InotifyFileWatcher watches files and is notified of events on them. +// +// Watcher code based on https://github.com/fsnotify/fsnotify +type InotifyFileWatcher struct { + inotify struct { + fd int + buf []byte + } + epoll struct { + fd int // epoll_create1() file descriptor + r int // read pipe file descriptor + w int // write pipe file descriptor + events []unix.EpollEvent + } + events chan FileEvent + + mu sync.Mutex + watches map[string]int + paths map[int]string + notExists map[string]struct{} + + g errgroup.Group + ctx context.Context + cancel func() +} + +// NewInotifyFileWatcher returns a new instance of InotifyFileWatcher. +func NewInotifyFileWatcher() *InotifyFileWatcher { + w := &InotifyFileWatcher{ + events: make(chan FileEvent), + + watches: make(map[string]int), + paths: make(map[int]string), + notExists: make(map[string]struct{}), + } + + w.inotify.buf = make([]byte, 4096*unix.SizeofInotifyEvent) + w.epoll.events = make([]unix.EpollEvent, 64) + + return w +} + +// NewFileWatcher returns an instance of InotifyFileWatcher on Linux systems. +func NewFileWatcher() FileWatcher { + return NewInotifyFileWatcher() +} + +// Events returns a read-only channel of file events. +func (w *InotifyFileWatcher) Events() <-chan FileEvent { + return w.events +} + +// Open initializes the watcher and begins listening for file events. +func (w *InotifyFileWatcher) Open() (err error) { + w.inotify.fd, err = unix.InotifyInit1(unix.IN_CLOEXEC) + if err != nil { + return fmt.Errorf("cannot init inotify: %w", err) + } + + // Initialize epoll and create a non-blocking pipe. + if w.epoll.fd, err = unix.EpollCreate1(unix.EPOLL_CLOEXEC); err != nil { + return fmt.Errorf("cannot create epoll: %w", err) + } + + pipe := []int{-1, -1} + if err := unix.Pipe2(pipe[:], unix.O_NONBLOCK|unix.O_CLOEXEC); err != nil { + return fmt.Errorf("cannot create epoll pipe: %w", err) + } + w.epoll.r, w.epoll.w = pipe[0], pipe[1] + + // Register inotify fd with epoll + if err := unix.EpollCtl(w.epoll.fd, unix.EPOLL_CTL_ADD, w.inotify.fd, &unix.EpollEvent{ + Fd: int32(w.inotify.fd), + Events: unix.EPOLLIN, + }); err != nil { + return fmt.Errorf("cannot add inotify to epoll: %w", err) + } + + // Register pipe fd with epoll + if err := unix.EpollCtl(w.epoll.fd, unix.EPOLL_CTL_ADD, w.epoll.r, &unix.EpollEvent{ + Fd: int32(w.epoll.r), + Events: unix.EPOLLIN, + }); err != nil { + return fmt.Errorf("cannot add pipe to epoll: %w", err) + } + + w.ctx, w.cancel = context.WithCancel(context.Background()) + w.g.Go(func() error { + if err := w.monitor(w.ctx); err != nil && w.ctx.Err() == nil { + return err + } + return nil + }) + w.g.Go(func() error { + if err := w.monitorNotExists(w.ctx); err != nil && w.ctx.Err() == nil { + return err + } + return nil + }) + + return nil +} + +// Close stops watching for file events and cleans up resources. +func (w *InotifyFileWatcher) Close() (err error) { + w.cancel() + + if e := w.wake(); e != nil && err == nil { + err = e + } + if e := w.g.Wait(); e != nil && err == nil { + err = e + } + return err +} + +// Watch begins watching the given file or directory. +func (w *InotifyFileWatcher) Watch(filename string) error { + w.mu.Lock() + defer w.mu.Unlock() + + filename = filepath.Clean(filename) + + // If file doesn't exist, monitor separately until it does exist as we + // can't watch non-existent files with inotify. + if _, err := os.Stat(filename); os.IsNotExist(err) { + w.notExists[filename] = struct{}{} + return nil + } + + return w.addWatch(filename) +} + +func (w *InotifyFileWatcher) addWatch(filename string) error { + wd, err := unix.InotifyAddWatch(w.inotify.fd, filename, unix.IN_MODIFY|unix.IN_DELETE_SELF) + if err != nil { + return err + } + + w.watches[filename] = wd + w.paths[wd] = filename + + delete(w.notExists, filename) + + return err +} + +// Unwatch stops watching the given file or directory. +func (w *InotifyFileWatcher) Unwatch(filename string) error { + w.mu.Lock() + defer w.mu.Unlock() + + filename = filepath.Clean(filename) + + // Look up watch ID by filename. + wd, ok := w.watches[filename] + if !ok { + return nil + } + + if _, err := unix.InotifyRmWatch(w.inotify.fd, uint32(wd)); err != nil { + return err + } + + delete(w.paths, wd) + delete(w.watches, filename) + delete(w.notExists, filename) + + return nil +} + +// monitorNotExist runs in a separate goroutine and monitors for the creation of +// watched files that do not yet exist. +func (w *InotifyFileWatcher) monitorNotExists(ctx context.Context) error { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + w.checkNotExists(ctx) + } + } +} + +func (w *InotifyFileWatcher) checkNotExists(ctx context.Context) { + w.mu.Lock() + defer w.mu.Unlock() + + for filename := range w.notExists { + if _, err := os.Stat(filename); os.IsNotExist(err) { + continue + } + + if err := w.addWatch(filename); err != nil { + log.Printf("non-existent file monitor: cannot add watch: %s", err) + } + + // Send event to channel. + select { + case w.events <- FileEvent{ + Name: filename, + Mask: FileEventCreated, + }: + default: + } + } +} + +// monitor runs in a separate goroutine and monitors the inotify event queue. +func (w *InotifyFileWatcher) monitor(ctx context.Context) error { + // Close all file descriptors once monitor exits. + defer func() { + unix.Close(w.inotify.fd) + unix.Close(w.epoll.fd) + unix.Close(w.epoll.w) + unix.Close(w.epoll.r) + }() + + for { + if err := w.wait(ctx); err != nil { + return err + } else if err := w.read(ctx); err != nil { + return err + } + } +} + +// read reads from the inotify file descriptor. Automatically rety on EINTR. +func (w *InotifyFileWatcher) read(ctx context.Context) error { + for { + n, err := unix.Read(w.inotify.fd, w.inotify.buf) + if err != nil && err != unix.EINTR { + return err + } else if n < 0 { + continue + } + + return w.recv(ctx, w.inotify.buf[:n]) + } +} + +func (w *InotifyFileWatcher) recv(ctx context.Context, b []byte) error { + if err := ctx.Err(); err != nil { + return err + } + + for { + if len(b) == 0 { + return nil + } else if len(b) < unix.SizeofInotifyEvent { + return fmt.Errorf("InotifyFileWatcher.recv(): inotify short record: n=%d", len(b)) + } + + event := (*unix.InotifyEvent)(unsafe.Pointer(&b[0])) + if event.Mask&unix.IN_Q_OVERFLOW != 0 { + // TODO: Change to notify all watches. + return ErrFileEventQueueOverflow + } + + // Remove deleted files from the lookups. + w.mu.Lock() + name, ok := w.paths[int(event.Wd)] + if ok && event.Mask&unix.IN_DELETE_SELF != 0 { + delete(w.paths, int(event.Wd)) + delete(w.watches, name) + } + w.mu.Unlock() + + //if nameLen > 0 { + // // Point "bytes" at the first byte of the filename + // bytes := (*[unix.PathMax]byte)(unsafe.Pointer(&buf[offset+unix.SizeofInotifyEvent]))[:nameLen:nameLen] + // // The filename is padded with NULL bytes. TrimRight() gets rid of those. + // name += "/" + strings.TrimRight(string(bytes[0:nameLen]), "\000") + //} + + // Move to next event. + b = b[unix.SizeofInotifyEvent+event.Len:] + + // Skip event if ignored. + if event.Mask&unix.IN_IGNORED != 0 { + continue + } + + // Convert to generic file event mask. + var mask int + if event.Mask&unix.IN_MODIFY != 0 { + mask |= FileEventModified + } + if event.Mask&unix.IN_DELETE_SELF != 0 { + mask |= FileEventDeleted + } + + // Send event to channel or wait for close. + select { + case <-ctx.Done(): + return ctx.Err() + case w.events <- FileEvent{ + Name: name, + Mask: mask, + }: + } + } +} + +func (w *InotifyFileWatcher) wait(ctx context.Context) error { + for { + n, err := unix.EpollWait(w.epoll.fd, w.epoll.events, -1) + if n == 0 || err == unix.EINTR { + continue + } else if err != nil { + return err + } + + // Read events to see if we have data available on inotify or if we are awaken. + var hasData bool + for _, event := range w.epoll.events[:n] { + switch event.Fd { + case int32(w.inotify.fd): // inotify file descriptor + hasData = hasData || event.Events&(unix.EPOLLHUP|unix.EPOLLERR|unix.EPOLLIN) != 0 + + case int32(w.epoll.r): // epoll file descriptor + if _, err := unix.Read(w.epoll.r, make([]byte, 1024)); err != nil && err != unix.EAGAIN { + return fmt.Errorf("epoll pipe error: %w", err) + } + } + } + + // Check if context is closed and then exit if data is available. + if err := ctx.Err(); err != nil { + return err + } else if hasData { + return nil + } + } +} + +func (w *InotifyFileWatcher) wake() error { + if _, err := unix.Write(w.epoll.w, []byte{0}); err != nil && err != unix.EAGAIN { + return err + } + return nil +} diff --git a/internal/file_watcher_test.go b/internal/file_watcher_test.go new file mode 100644 index 00000000..dd767154 --- /dev/null +++ b/internal/file_watcher_test.go @@ -0,0 +1,211 @@ +package internal_test + +import ( + "database/sql" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/benbjohnson/litestream/internal" + _ "github.com/mattn/go-sqlite3" +) + +func TestFileWatcher(t *testing.T) { + t.Run("WriteAndRemove", func(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "db") + + w := internal.NewFileWatcher() + if err := w.Open(); err != nil { + t.Fatal(err) + } + defer w.Close() + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + if _, err := db.Exec(`PRAGMA journal_mode = wal`); err != nil { + t.Fatal(err) + } else if _, err := db.Exec(`CREATE TABLE t (x)`); err != nil { + t.Fatal(err) + } + + if err := w.Watch(dbPath + "-wal"); err != nil { + t.Fatal(err) + } + + // Write to the WAL file & ensure a "modified" event occurs. + if _, err := db.Exec(`INSERT INTO t (x) VALUES (1)`); err != nil { + t.Fatal(err) + } + + select { + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for event") + case event := <-w.Events(): + if got, want := event.Name, dbPath+"-wal"; got != want { + t.Fatalf("name=%s, want %s", got, want) + } else if got, want := event.Mask, internal.FileEventModified; got != want { + t.Fatalf("mask=0x%02x, want 0x%02x", got, want) + } + } + + // Flush any duplicate events. + drainFileEventChannel(w.Events()) + + // Close database and ensure checkpointed WAL creates a "delete" event. + if err := db.Close(); err != nil { + t.Fatal(err) + } + + select { + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for event") + case event := <-w.Events(): + if got, want := event.Name, dbPath+"-wal"; got != want { + t.Fatalf("name=%s, want %s", got, want) + } else if got, want := event.Mask, internal.FileEventDeleted; got != want { + t.Fatalf("mask=0x%02x, want 0x%02x", got, want) + } + } + }) + + t.Run("LargeTx", func(t *testing.T) { + w := internal.NewFileWatcher() + if err := w.Open(); err != nil { + t.Fatal(err) + } + defer w.Close() + + dbPath := filepath.Join(t.TempDir(), "db") + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatal(err) + } else if _, err := db.Exec(`PRAGMA cache_size = 4`); err != nil { + t.Fatal(err) + } else if _, err := db.Exec(`PRAGMA journal_mode = wal`); err != nil { + t.Fatal(err) + } else if _, err := db.Exec(`CREATE TABLE t (x)`); err != nil { + t.Fatal(err) + } + defer db.Close() + + if err := w.Watch(dbPath + "-wal"); err != nil { + t.Fatal(err) + } + + // Start a transaction to ensure writing large data creates multiple write events. + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer func() { _ = tx.Rollback() }() + + // Write enough data to require a spill. + for i := 0; i < 100; i++ { + if _, err := tx.Exec(`INSERT INTO t (x) VALUES (?)`, strings.Repeat("x", 512)); err != nil { + t.Fatal(err) + } + } + + // Ensure spill writes to disk. + select { + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for event") + case event := <-w.Events(): + if got, want := event.Name, dbPath+"-wal"; got != want { + t.Fatalf("name=%s, want %s", got, want) + } else if got, want := event.Mask, internal.FileEventModified; got != want { + t.Fatalf("mask=0x%02x, want 0x%02x", got, want) + } + } + + // Flush any duplicate events. + drainFileEventChannel(w.Events()) + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + // Final commit should spill remaining pages and cause another write event. + select { + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for event") + case event := <-w.Events(): + if got, want := event.Name, dbPath+"-wal"; got != want { + t.Fatalf("name=%s, want %s", got, want) + } else if got, want := event.Mask, internal.FileEventModified; got != want { + t.Fatalf("mask=0x%02x, want 0x%02x", got, want) + } + } + }) + + t.Run("WatchBeforeCreate", func(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "db") + + w := internal.NewFileWatcher() + if err := w.Open(); err != nil { + t.Fatal(err) + } + defer w.Close() + + if err := w.Watch(dbPath); err != nil { + t.Fatal(err) + } else if err := w.Watch(dbPath + "-wal"); err != nil { + t.Fatal(err) + } + + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + if _, err := db.Exec(`CREATE TABLE t (x)`); err != nil { + t.Fatal(err) + } + + // Wait for main database creation event. + waitForFileEvent(t, w.Events(), internal.FileEvent{Name: dbPath, Mask: internal.FileEventCreated}) + + // Write to the WAL file & ensure a "modified" event occurs. + if _, err := db.Exec(`PRAGMA journal_mode = wal`); err != nil { + t.Fatal(err) + } else if _, err := db.Exec(`INSERT INTO t (x) VALUES (1)`); err != nil { + t.Fatal(err) + } + + // Wait for WAL creation event. + waitForFileEvent(t, w.Events(), internal.FileEvent{Name: dbPath + "-wal", Mask: internal.FileEventCreated}) + }) +} + +func drainFileEventChannel(ch <-chan internal.FileEvent) { + for { + select { + case <-time.After(100 * time.Millisecond): + return + case <-ch: + } + } +} + +func waitForFileEvent(tb testing.TB, ch <-chan internal.FileEvent, want internal.FileEvent) { + tb.Helper() + + timeout := time.After(10 * time.Second) + + for { + select { + case <-timeout: + tb.Fatalf("timeout waiting for event: %#v", want) + case got := <-ch: + if got == want { + return + } + } + } +} diff --git a/internal/internal.go b/internal/internal.go index b2db3b98..681726a9 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -239,7 +239,7 @@ func TruncateDuration(d time.Duration) time.Duration { return d } -// MD5hash returns a hex-encoded MD5 hash of b. -func MD5hash(b []byte) string { +// MD5Hash returns a hex-encoded MD5 hash of b. +func MD5Hash(b []byte) string { return fmt.Sprintf("%x", md5.Sum(b)) } diff --git a/replica.go b/replica.go index 95b30e5a..e401d8c2 100644 --- a/replica.go +++ b/replica.go @@ -662,7 +662,7 @@ func (r *Replica) monitor(ctx context.Context) { } // Fetch new notify channel before replicating data. - notify = r.db.Notify() + notify = r.db.WALNotify() // Synchronize the shadow wal into the replication directory. if err := r.Sync(ctx); err != nil { diff --git a/server.go b/server.go new file mode 100644 index 00000000..bd24ecdc --- /dev/null +++ b/server.go @@ -0,0 +1,188 @@ +package litestream + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/benbjohnson/litestream/internal" + "golang.org/x/sync/errgroup" +) + +// Server represents the top-level container. +// It manage databases and routes global file system events. +type Server struct { + mu sync.Mutex + dbs map[string]*DB // databases by path + watcher internal.FileWatcher + + ctx context.Context + cancel func() + errgroup errgroup.Group +} + +// NewServer returns a new instance of Server. +func NewServer() *Server { + return &Server{ + dbs: make(map[string]*DB), + } +} + +// Open initializes the server and begins watching for file system events. +func (s *Server) Open() error { + s.watcher = internal.NewFileWatcher() + if err := s.watcher.Open(); err != nil { + return err + } + + s.ctx, s.cancel = context.WithCancel(context.Background()) + s.errgroup.Go(func() error { + if err := s.monitor(s.ctx); err != nil && err != context.Canceled { + return fmt.Errorf("server monitor error: %w", err) + } + return nil + }) + return nil +} + +// Close shuts down the server and all databases it manages. +func (s *Server) Close() (err error) { + // Cancel context and wait for goroutines to finish. + s.cancel() + if e := s.errgroup.Wait(); e != nil && err == nil { + err = e + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.watcher != nil { + if e := s.watcher.Close(); e != nil && err == nil { + err = fmt.Errorf("close watcher: %w", e) + } + } + + for _, db := range s.dbs { + if e := db.Close(); e != nil && err == nil { + err = fmt.Errorf("close db: path=%s err=%w", db.Path(), e) + } + } + s.dbs = make(map[string]*DB) + + return err +} + +// DB returns the database with the given path, if it's managed by the server. +func (s *Server) DB(path string) *DB { + s.mu.Lock() + defer s.mu.Unlock() + return s.dbs[path] +} + +// DBs returns a slice of all databases managed by the server. +func (s *Server) DBs() []*DB { + s.mu.Lock() + defer s.mu.Unlock() + + a := make([]*DB, 0, len(s.dbs)) + for _, db := range s.dbs { + a = append(a, db) + } + return a +} + +// Watch adds a database path to be managed by the server. +func (s *Server) Watch(path string, fn func(path string) (*DB, error)) error { + s.mu.Lock() + defer s.mu.Unlock() + + // TODO: Watch for path if DB or WAL doesn't exist yet. + + // Instantiate DB from factory function. + db, err := fn(path) + if err != nil { + return fmt.Errorf("new database: %w", err) + } + + // Start watching the database for changes. + if err := db.Open(); err != nil { + return fmt.Errorf("open database: %w", err) + } + s.dbs[path] = db + + // Watch for changes on the database file & WAL. + if err := s.watcher.Watch(path); err != nil { + return fmt.Errorf("watch db file: %w", err) + } else if err := s.watcher.Watch(path + "-wal"); err != nil { + return fmt.Errorf("watch wal file: %w", err) + } + + // Kick off an initial sync. + select { + case db.NotifyCh() <- struct{}{}: + default: + } + + return nil +} + +// Unwatch removes a database path from being managed by the server. +func (s *Server) Unwatch(path string) error { + s.mu.Lock() + defer s.mu.Unlock() + + db := s.dbs[path] + if db == nil { + return nil + } + delete(s.dbs, path) + + // Stop watching for changes on the database WAL. + if err := s.watcher.Unwatch(path + "-wal"); err != nil { + return fmt.Errorf("unwatch file: %w", err) + } + + // Shut down database. + if err := db.Close(); err != nil { + return fmt.Errorf("close db: %w", err) + } + + return nil +} + +// monitor runs in a separate goroutine and dispatches notifications to managed DBs. +func (s *Server) monitor(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case event := <-s.watcher.Events(): + if err := s.dispatchFileEvent(ctx, event); err != nil { + return err + } + } + } +} + +// dispatchFileEvent dispatches a notification to the database which owns the file. +func (s *Server) dispatchFileEvent(ctx context.Context, event internal.FileEvent) error { + path := event.Name + path = strings.TrimSuffix(path, "-wal") + + db := s.DB(path) + if db == nil { + return nil + } + + // TODO: If deleted, remove from server and close DB. + + select { + case <-ctx.Done(): + return ctx.Err() + case db.NotifyCh() <- struct{}{}: + return nil // notify db + default: + return nil // already pending notification, skip + } +} diff --git a/server_test.go b/server_test.go new file mode 100644 index 00000000..3d7601f0 --- /dev/null +++ b/server_test.go @@ -0,0 +1 @@ +package litestream_test