From a34790e494ffc76fb64b0ec25bfd0a5db58d1798 Mon Sep 17 00:00:00 2001 From: positiveblue Date: Fri, 5 Nov 2021 12:55:59 -0700 Subject: [PATCH 1/6] account/watcher: add `Controller` interface Use the new watcher interface in the account manager instead of of a specific struct type. --- account/manager.go | 30 +- account/watcher/{watcher.go => controller.go} | 147 +++++----- account/watcher/interfaces.go | 46 +++ account/watcher/mock_test.go | 48 ---- account/watcher/watcher_test.go | 269 ------------------ 5 files changed, 136 insertions(+), 404 deletions(-) rename account/watcher/{watcher.go => controller.go} (75%) create mode 100644 account/watcher/interfaces.go delete mode 100644 account/watcher/mock_test.go delete mode 100644 account/watcher/watcher_test.go diff --git a/account/manager.go b/account/manager.go index 63ebeb6d0..a11efaba9 100644 --- a/account/manager.go +++ b/account/manager.go @@ -127,8 +127,8 @@ type Manager struct { started sync.Once stopped sync.Once - cfg ManagerConfig - watcher *watcher.Watcher + cfg ManagerConfig + watcherCtrl watcher.Controller // pendingBatchMtx guards access to any database calls involving pending // batches. This is mostly used to prevent race conditions when handling @@ -153,7 +153,7 @@ func NewManager(cfg *ManagerConfig) *Manager { quit: make(chan struct{}), } - m.watcher = watcher.New(&watcher.Config{ + m.watcherCtrl = watcher.NewController(&watcher.Config{ ChainNotifier: cfg.ChainNotifier, HandleAccountConf: m.handleAccountConf, HandleAccountSpend: m.handleAccountSpend, @@ -178,7 +178,7 @@ func (m *Manager) start() error { // We'll start by resuming all of our accounts. This requires the // watcher to be started first. - if err := m.watcher.Start(); err != nil { + if err := m.watcherCtrl.Start(); err != nil { return err } @@ -229,7 +229,7 @@ func (m *Manager) start() error { // Stop safely stops any ongoing operations within the Manager. func (m *Manager) Stop() { m.stopped.Do(func() { - m.watcher.Stop() + m.watcherCtrl.Stop() close(m.quit) m.wg.Wait() @@ -382,8 +382,8 @@ func (m *Manager) WatchMatchedAccounts(ctx context.Context, // canceling all previous spend and confirmation watchers. We // then only watch the latest batch and once it confirms, create // a new spend watcher on that. - m.watcher.CancelAccountSpend(matchedAccount) - m.watcher.CancelAccountConf(matchedAccount) + m.watcherCtrl.CancelAccountSpend(matchedAccount) + m.watcherCtrl.CancelAccountConf(matchedAccount) // After taking part in a batch, the account is either pending // closed because it was used up or pending batch update because @@ -620,7 +620,7 @@ func (m *Manager) resumeAccount(ctx context.Context, account *Account, // nolint ) log.Infof("Waiting for %v confirmation(s) of account %x", numConfs, account.TraderKey.PubKey.SerializeCompressed()) - err = m.watcher.WatchAccountConf( + err = m.watcherCtrl.WatchAccountConf( account.TraderKey.PubKey, account.OutPoint.Hash, accountOutput.PkScript, numConfs, account.HeightHint, ) @@ -657,7 +657,7 @@ func (m *Manager) resumeAccount(ctx context.Context, account *Account, // nolint ) log.Infof("Waiting for %v confirmation(s) of account %x", numConfs, account.TraderKey.PubKey.SerializeCompressed()) - err = m.watcher.WatchAccountConf( + err = m.watcherCtrl.WatchAccountConf( account.TraderKey.PubKey, account.OutPoint.Hash, accountOutput.PkScript, numConfs, account.HeightHint, ) @@ -706,7 +706,7 @@ func (m *Manager) resumeAccount(ctx context.Context, account *Account, // nolint log.Infof("Waiting for %v confirmation(s) of expired account %x", numConfs, account.TraderKey.PubKey.SerializeCompressed()) - err = m.watcher.WatchAccountConf( + err = m.watcherCtrl.WatchAccountConf( account.TraderKey.PubKey, account.OutPoint.Hash, accountOutput.PkScript, numConfs, account.HeightHint, ) @@ -721,7 +721,7 @@ func (m *Manager) resumeAccount(ctx context.Context, account *Account, // nolint log.Infof("Watching expired account %x for spend", account.TraderKey.PubKey.SerializeCompressed()) - err = m.watcher.WatchAccountSpend( + err = m.watcherCtrl.WatchAccountSpend( account.TraderKey.PubKey, account.OutPoint, accountOutput.PkScript, account.HeightHint, ) @@ -745,7 +745,7 @@ func (m *Manager) resumeAccount(ctx context.Context, account *Account, // nolint log.Infof("Watching account %x for spend", account.TraderKey.PubKey.SerializeCompressed()) - err = m.watcher.WatchAccountSpend( + err = m.watcherCtrl.WatchAccountSpend( account.TraderKey.PubKey, account.OutPoint, accountOutput.PkScript, account.HeightHint, ) @@ -836,7 +836,7 @@ func (m *Manager) handleStateOpen(ctx context.Context, account *Account) error { return err } - err = m.watcher.WatchAccountSpend( + err = m.watcherCtrl.WatchAccountSpend( account.TraderKey.PubKey, account.OutPoint, accountOutput.PkScript, account.HeightHint, ) @@ -844,7 +844,7 @@ func (m *Manager) handleStateOpen(ctx context.Context, account *Account) error { return fmt.Errorf("unable to watch for spend: %v", err) } - err = m.watcher.WatchAccountExpiration( + err = m.watcherCtrl.WatchAccountExpiration( account.TraderKey.PubKey, account.Expiry, ) if err != nil { @@ -1253,7 +1253,7 @@ func (m *Manager) RenewAccount(ctx context.Context, // Begin to track the new account expiration, which will overwrite the // existing expiration request. - err = m.watcher.WatchAccountExpiration(traderKey, modifiedAccount.Expiry) + err = m.watcherCtrl.WatchAccountExpiration(traderKey, modifiedAccount.Expiry) if err != nil { return nil, nil, err } diff --git a/account/watcher/watcher.go b/account/watcher/controller.go similarity index 75% rename from account/watcher/watcher.go rename to account/watcher/controller.go index d21ac630c..09f05a492 100644 --- a/account/watcher/watcher.go +++ b/account/watcher/controller.go @@ -24,7 +24,7 @@ type expiryReq struct { expiry uint32 } -// Config contains all of the Watcher's dependencies in order to carry out its +// Config contains all of the Controller's dependencies in order to carry out its // duties. type Config struct { // ChainNotifier is responsible for requesting confirmation and spend @@ -47,9 +47,8 @@ type Config struct { HandleAccountExpiry func(*btcec.PublicKey, uint32) error } -// Watcher is responsible for the on-chain interaction of an account, whether -// that is confirmation or spend. -type Watcher struct { +// controller implements the Controller interface +type controller struct { started sync.Once stopped sync.Once @@ -66,9 +65,13 @@ type Watcher struct { confCancels map[[33]byte]func() } -// New instantiates a new chain watcher backed by the given config. -func New(cfg *Config) *Watcher { - return &Watcher{ +// Compile time assertion that controller implements the Controller interface? +var _ Controller = (*controller)(nil) + +// NewController returns an internal struct type that implements the +// Controller interface. +func NewController(cfg *Config) *controller { // nolint:golint + return &controller{ cfg: *cfg, expiryReqs: make(chan *expiryReq), quit: make(chan struct{}), @@ -78,58 +81,58 @@ func New(cfg *Config) *Watcher { } // Start allows the Watcher to begin accepting watch requests. -func (w *Watcher) Start() error { +func (c *controller) Start() error { var err error - w.started.Do(func() { - err = w.start() + c.started.Do(func() { + err = c.start() }) return err } // start allows the Watcher to begin accepting watch requests. -func (w *Watcher) start() error { +func (c *controller) start() error { ctxc, cancel := context.WithCancel(context.Background()) - blockChan, errChan, err := w.cfg.ChainNotifier.RegisterBlockEpochNtfn( + blockChan, errChan, err := c.cfg.ChainNotifier.RegisterBlockEpochNtfn( ctxc, ) if err != nil { cancel() return err } - w.ctxCancels = append(w.ctxCancels, cancel) + c.ctxCancels = append(c.ctxCancels, cancel) - w.wg.Add(1) - go w.expiryHandler(blockChan, errChan) + c.wg.Add(1) + go c.expiryHandler(blockChan, errChan) return nil } // Stop safely stops any ongoing requests within the Watcher. -func (w *Watcher) Stop() { - w.stopped.Do(func() { - close(w.quit) - w.wg.Wait() +func (c *controller) Stop() { + c.stopped.Do(func() { + close(c.quit) + c.wg.Wait() - for _, cancel := range w.ctxCancels { + for _, cancel := range c.ctxCancels { cancel() } - w.cancelMtx.Lock() - for _, cancel := range w.spendCancels { + c.cancelMtx.Lock() + for _, cancel := range c.spendCancels { cancel() } - for _, cancel := range w.confCancels { + for _, cancel := range c.confCancels { cancel() } - w.cancelMtx.Unlock() + c.cancelMtx.Unlock() }) } // expiryHandler receives block notifications to determine when accounts expire. // // NOTE: This must be run as a goroutine. -func (w *Watcher) expiryHandler(blockChan chan int32, errChan chan error) { - defer w.wg.Done() +func (c *controller) expiryHandler(blockChan chan int32, errChan chan error) { + defer c.wg.Done() var ( // bestHeight is the height we believe the current chain is at. @@ -152,7 +155,7 @@ func (w *Watcher) expiryHandler(blockChan chan int32, errChan chan error) { case err := <-errChan: log.Errorf("Unable to receive initial block notification: %v", err) - case <-w.quit: + case <-c.quit: return } @@ -177,7 +180,7 @@ func (w *Watcher) expiryHandler(blockChan chan int32, errChan chan error) { continue } - err := w.cfg.HandleAccountExpiry( + err := c.cfg.HandleAccountExpiry( traderKey, bestHeight, ) if err != nil { @@ -196,13 +199,13 @@ func (w *Watcher) expiryHandler(blockChan chan int32, errChan chan error) { err) // A new watch expiry request has been received for an account. - case req := <-w.expiryReqs: + case req := <-c.expiryReqs: var accountKey [33]byte copy(accountKey[:], req.traderKey.SerializeCompressed()) // If it's already expired, we don't need to track it. if req.expiry <= bestHeight { - err := w.cfg.HandleAccountExpiry( + err := c.cfg.HandleAccountExpiry( req.traderKey, bestHeight, ) if err != nil { @@ -221,7 +224,7 @@ func (w *Watcher) expiryHandler(blockChan chan int32, errChan chan error) { expirationsPerHeight[req.expiry], req.traderKey, ) - case <-w.quit: + case <-c.quit: return } } @@ -232,33 +235,33 @@ func (w *Watcher) expiryHandler(blockChan chan int32, errChan chan error) { // // NOTE: If there is a previous conf watcher for the given account that has not // finished yet, it will be canceled! -func (w *Watcher) WatchAccountConf(traderKey *btcec.PublicKey, +func (c *controller) WatchAccountConf(traderKey *btcec.PublicKey, txHash chainhash.Hash, script []byte, numConfs, heightHint uint32) error { - w.cancelMtx.Lock() - defer w.cancelMtx.Unlock() + c.cancelMtx.Lock() + defer c.cancelMtx.Unlock() var traderKeyRaw [33]byte copy(traderKeyRaw[:], traderKey.SerializeCompressed()) // Cancel a previous conf watcher if one still exists. - cancel, ok := w.confCancels[traderKeyRaw] + cancel, ok := c.confCancels[traderKeyRaw] if ok { cancel() } ctxc, cancel := context.WithCancel(context.Background()) - confChan, errChan, err := w.cfg.ChainNotifier.RegisterConfirmationsNtfn( + confChan, errChan, err := c.cfg.ChainNotifier.RegisterConfirmationsNtfn( ctxc, &txHash, script, int32(numConfs), int32(heightHint), ) if err != nil { cancel() return err } - w.confCancels[traderKeyRaw] = cancel + c.confCancels[traderKeyRaw] = cancel - w.wg.Add(1) - go w.waitForAccountConf(traderKey, traderKeyRaw, confChan, errChan) + c.wg.Add(1) + go c.waitForAccountConf(traderKey, traderKeyRaw, confChan, errChan) return nil } @@ -267,21 +270,21 @@ func (w *Watcher) WatchAccountConf(traderKey *btcec.PublicKey, // necessary steps once confirmed. // // NOTE: This method must be run as a goroutine. -func (w *Watcher) waitForAccountConf(traderKey *btcec.PublicKey, +func (c *controller) waitForAccountConf(traderKey *btcec.PublicKey, traderKeyRaw [33]byte, confChan chan *chainntnfs.TxConfirmation, errChan chan error) { defer func() { - w.wg.Done() + c.wg.Done() - w.cancelMtx.Lock() - delete(w.confCancels, traderKeyRaw) - w.cancelMtx.Unlock() + c.cancelMtx.Lock() + delete(c.confCancels, traderKeyRaw) + c.cancelMtx.Unlock() }() select { case conf := <-confChan: - if err := w.cfg.HandleAccountConf(traderKey, conf); err != nil { + if err := c.cfg.HandleAccountConf(traderKey, conf); err != nil { log.Errorf("Unable to handle confirmation for account "+ "%x: %v", traderKey.SerializeCompressed(), err) } @@ -300,7 +303,7 @@ func (w *Watcher) waitForAccountConf(traderKey *btcec.PublicKey, traderKey.SerializeCompressed(), err) } - case <-w.quit: + case <-c.quit: return } } @@ -310,33 +313,33 @@ func (w *Watcher) waitForAccountConf(traderKey *btcec.PublicKey, // // NOTE: If there is a previous spend watcher for the given account that has not // finished yet, it will be canceled! -func (w *Watcher) WatchAccountSpend(traderKey *btcec.PublicKey, +func (c *controller) WatchAccountSpend(traderKey *btcec.PublicKey, accountPoint wire.OutPoint, script []byte, heightHint uint32) error { - w.cancelMtx.Lock() - defer w.cancelMtx.Unlock() + c.cancelMtx.Lock() + defer c.cancelMtx.Unlock() var traderKeyRaw [33]byte copy(traderKeyRaw[:], traderKey.SerializeCompressed()) // Cancel a previous spend watcher if one still exists. - cancel, ok := w.spendCancels[traderKeyRaw] + cancel, ok := c.spendCancels[traderKeyRaw] if ok { cancel() } ctxc, cancel := context.WithCancel(context.Background()) - spendChan, errChan, err := w.cfg.ChainNotifier.RegisterSpendNtfn( + spendChan, errChan, err := c.cfg.ChainNotifier.RegisterSpendNtfn( ctxc, &accountPoint, script, int32(heightHint), ) if err != nil { cancel() return err } - w.spendCancels[traderKeyRaw] = cancel + c.spendCancels[traderKeyRaw] = cancel - w.wg.Add(1) - go w.waitForAccountSpend(traderKey, traderKeyRaw, spendChan, errChan) + c.wg.Add(1) + go c.waitForAccountSpend(traderKey, traderKeyRaw, spendChan, errChan) return nil } @@ -345,21 +348,21 @@ func (w *Watcher) WatchAccountSpend(traderKey *btcec.PublicKey, // steps once spent. // // NOTE: This method must be run as a goroutine. -func (w *Watcher) waitForAccountSpend(traderKey *btcec.PublicKey, +func (c *controller) waitForAccountSpend(traderKey *btcec.PublicKey, traderKeyRaw [33]byte, spendChan chan *chainntnfs.SpendDetail, errChan chan error) { defer func() { - w.wg.Done() + c.wg.Done() - w.cancelMtx.Lock() - delete(w.spendCancels, traderKeyRaw) - w.cancelMtx.Unlock() + c.cancelMtx.Lock() + delete(c.spendCancels, traderKeyRaw) + c.cancelMtx.Unlock() }() select { case spend := <-spendChan: - err := w.cfg.HandleAccountSpend(traderKey, spend) + err := c.cfg.HandleAccountSpend(traderKey, spend) if err != nil { log.Errorf("Unable to handle spend for account %x: %v", traderKey.SerializeCompressed(), err) @@ -378,7 +381,7 @@ func (w *Watcher) waitForAccountSpend(traderKey *btcec.PublicKey, "%v", traderKey.SerializeCompressed(), err) } - case <-w.quit: + case <-c.quit: return } } @@ -386,31 +389,31 @@ func (w *Watcher) waitForAccountSpend(traderKey *btcec.PublicKey, // WatchAccountExpiration watches for the expiration of an account on-chain. // Successive calls for the same account will cancel any previous expiration // watch requests and the new expiration will be tracked instead. -func (w *Watcher) WatchAccountExpiration(traderKey *btcec.PublicKey, +func (c *controller) WatchAccountExpiration(traderKey *btcec.PublicKey, expiry uint32) error { select { - case w.expiryReqs <- &expiryReq{ + case c.expiryReqs <- &expiryReq{ traderKey: traderKey, expiry: expiry, }: return nil - case <-w.quit: + case <-c.quit: return errors.New("watcher shutting down") } } // CancelAccountSpend cancels the spend watcher of the given account, if one is // active. -func (w *Watcher) CancelAccountSpend(traderKey *btcec.PublicKey) { - w.cancelMtx.Lock() - defer w.cancelMtx.Unlock() +func (c *controller) CancelAccountSpend(traderKey *btcec.PublicKey) { + c.cancelMtx.Lock() + defer c.cancelMtx.Unlock() var traderKeyRaw [33]byte copy(traderKeyRaw[:], traderKey.SerializeCompressed()) - cancel, ok := w.spendCancels[traderKeyRaw] + cancel, ok := c.spendCancels[traderKeyRaw] if ok { cancel() } @@ -418,14 +421,14 @@ func (w *Watcher) CancelAccountSpend(traderKey *btcec.PublicKey) { // CancelAccountConf cancels the conf watcher of the given account, if one is // active. -func (w *Watcher) CancelAccountConf(traderKey *btcec.PublicKey) { - w.cancelMtx.Lock() - defer w.cancelMtx.Unlock() +func (c *controller) CancelAccountConf(traderKey *btcec.PublicKey) { + c.cancelMtx.Lock() + defer c.cancelMtx.Unlock() var traderKeyRaw [33]byte copy(traderKeyRaw[:], traderKey.SerializeCompressed()) - cancel, ok := w.confCancels[traderKeyRaw] + cancel, ok := c.confCancels[traderKeyRaw] if ok { cancel() } diff --git a/account/watcher/interfaces.go b/account/watcher/interfaces.go new file mode 100644 index 000000000..c2cd4f7f3 --- /dev/null +++ b/account/watcher/interfaces.go @@ -0,0 +1,46 @@ +package watcher + +import ( + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" +) + +// Controller is the interface used by other components to communicate with the +// watcher. +type Controller interface { + // Start allows the Controller to begin accepting watch requests. + Start() error + + // Stop safely stops any ongoing requests within the Controller. + Stop() + + // WatchAccountConf watches a new account on-chain for its confirmation. Only + // one conf watcher per account can be used at any time. + // + // NOTE: If there is a previous conf watcher for the given account that has not + // finished yet, it will be canceled! + WatchAccountConf(traderKey *btcec.PublicKey, + txHash chainhash.Hash, script []byte, numConfs, heightHint uint32) error + + // CancelAccountConf cancels the conf watcher of the given account, if one is + // active. + CancelAccountConf(traderKey *btcec.PublicKey) + + // WatchAccountSpend watches for the spend of an account. Only one spend watcher + // per account can be used at any time. + // + // NOTE: If there is a previous spend watcher for the given account that has not + // finished yet, it will be canceled! + WatchAccountSpend(traderKey *btcec.PublicKey, + accountPoint wire.OutPoint, script []byte, heightHint uint32) error + + // CancelAccountSpend cancels the spend watcher of the given account, if one is + // active. + CancelAccountSpend(traderKey *btcec.PublicKey) + + // WatchAccountExpiration watches for the expiration of an account on-chain. + // Successive calls for the same account will cancel any previous expiration + // watch requests and the new expiration will be tracked instead. + WatchAccountExpiration(traderKey *btcec.PublicKey, expiry uint32) error +} diff --git a/account/watcher/mock_test.go b/account/watcher/mock_test.go deleted file mode 100644 index 14b80f0ad..000000000 --- a/account/watcher/mock_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package watcher - -import ( - "context" - - "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/btcsuite/btcd/wire" - "github.com/lightninglabs/lndclient" - "github.com/lightningnetwork/lnd/chainntnfs" -) - -type mockChainNotifier struct { - lndclient.ChainNotifierClient - - confChan chan *chainntnfs.TxConfirmation - spendChan chan *chainntnfs.SpendDetail - blockChan chan int32 - errChan chan error -} - -func newMockChainNotifier() *mockChainNotifier { - return &mockChainNotifier{ - confChan: make(chan *chainntnfs.TxConfirmation), - spendChan: make(chan *chainntnfs.SpendDetail), - blockChan: make(chan int32), - errChan: make(chan error), - } -} - -func (n *mockChainNotifier) RegisterConfirmationsNtfn(ctx context.Context, - txid *chainhash.Hash, pkScript []byte, numConfs, - heightHint int32) (chan *chainntnfs.TxConfirmation, chan error, error) { - - return n.confChan, n.errChan, nil -} - -func (n *mockChainNotifier) RegisterSpendNtfn(ctx context.Context, - outpoint *wire.OutPoint, pkScript []byte, - heightHint int32) (chan *chainntnfs.SpendDetail, chan error, error) { - - return n.spendChan, n.errChan, nil -} - -func (n *mockChainNotifier) RegisterBlockEpochNtfn( - ctx context.Context) (chan int32, chan error, error) { - - return n.blockChan, n.errChan, nil -} diff --git a/account/watcher/watcher_test.go b/account/watcher/watcher_test.go deleted file mode 100644 index 6879fcefa..000000000 --- a/account/watcher/watcher_test.go +++ /dev/null @@ -1,269 +0,0 @@ -package watcher - -import ( - "encoding/hex" - "testing" - "time" - - "github.com/btcsuite/btcd/btcec" - "github.com/btcsuite/btcd/wire" - "github.com/lightningnetwork/lnd/chainntnfs" -) - -const ( - timeout = 500 * time.Millisecond -) - -var ( - zeroOutPoint wire.OutPoint - rawTestTraderKey, _ = hex.DecodeString("02d0de0999f50eaacaae5b6e178eec7c8bd99dd797bc9f7cfb497e2188884d59f3") - testTraderKey, _ = btcec.ParsePubKey(rawTestTraderKey, btcec.S256()) - testScript, _ = hex.DecodeString("00149589c15e7a8a8065f75aad5f3337cfccf909174a") -) - -// TestWatcherConf ensures that the watcher performs its expected operations -// once an account confirmation has been detected. -func TestWatcherConf(t *testing.T) { - t.Parallel() - - // Set up the required dependencies of the Watcher. - notifier := newMockChainNotifier() - - // The HandleAccountConf closure will use a signal to indicate that it's - // been invoked once a confirmation notification is received. - confSignal := make(chan struct{}) - handleConf := func(*btcec.PublicKey, *chainntnfs.TxConfirmation) error { - close(confSignal) - return nil - } - - watcher := New(&Config{ - ChainNotifier: notifier, - HandleAccountConf: handleConf, - }) - if err := watcher.Start(); err != nil { - t.Fatalf("unable to start watcher: %v", err) - } - defer watcher.Stop() - - // Watch for an account's confirmation. - err := watcher.WatchAccountConf( - testTraderKey, zeroOutPoint.Hash, testScript, 1, 1, - ) - if err != nil { - t.Fatalf("unable to watch account conf: %v", err) - } - - // HandleAccountConf should not be invoked until after the confirmation. - select { - case <-confSignal: - t.Fatal("unexpected conf signal") - case <-time.After(timeout): - } - - select { - case notifier.confChan <- &chainntnfs.TxConfirmation{}: - case <-time.After(timeout): - t.Fatal("unable to notify conf") - } - - select { - case <-confSignal: - case <-time.After(timeout): - t.Fatal("expected conf signal") - } -} - -// TestWatcherSpend ensures that the watcher performs its expected operations -// once an account spend has been detected. -func TestWatcherSpend(t *testing.T) { - t.Parallel() - - // Set up the required dependencies of the Watcher. - notifier := newMockChainNotifier() - - // The HandleAccountSpend closure will use a signal to indicate that - // it's been invoked once a spend notification is received. - spendSignal := make(chan struct{}) - handleSpend := func(*btcec.PublicKey, *chainntnfs.SpendDetail) error { - close(spendSignal) - return nil - } - - watcher := New(&Config{ - ChainNotifier: notifier, - HandleAccountSpend: handleSpend, - }) - if err := watcher.Start(); err != nil { - t.Fatalf("unable to start watcher: %v", err) - } - defer watcher.Stop() - - // Watch for an account's spend. - err := watcher.WatchAccountSpend( - testTraderKey, zeroOutPoint, testScript, 1, - ) - if err != nil { - t.Fatalf("unable to watch account spend: %v", err) - } - - // HandleAccountSpend should not be invoked until after the spend. - select { - case <-spendSignal: - t.Fatal("unexpected spend signal") - case <-time.After(timeout): - } - - select { - case notifier.spendChan <- &chainntnfs.SpendDetail{}: - case <-time.After(timeout): - t.Fatal("unable to notify spend") - } - - select { - case <-spendSignal: - case <-time.After(timeout): - t.Fatal("expected spend signal") - } -} - -// TestWatcherExpiry ensures that the watcher performs its expected operations -// once an account expiration has been detected. -func TestWatcherExpiry(t *testing.T) { - t.Parallel() - - const ( - startHeight = 100 - expiryHeight = startHeight * 2 - ) - - // Set up the required dependencies of the Watcher. - notifier := newMockChainNotifier() - - // The HandleAccountExpiry closure will use a signal to indicate that - // it's been invoked once an expiry notification is received. - expirySignal := make(chan struct{}) - handleExpiry := func(*btcec.PublicKey, uint32) error { - close(expirySignal) - return nil - } - - watcher := New(&Config{ - ChainNotifier: notifier, - HandleAccountExpiry: handleExpiry, - }) - if err := watcher.Start(); err != nil { - t.Fatalf("unable to start watcher: %v", err) - } - defer watcher.Stop() - - select { - case notifier.blockChan <- startHeight: - case <-time.After(timeout): - t.Fatal("unable to notify block") - } - - // Watch for an account's expiration that has yet to expire. - err := watcher.WatchAccountExpiration(testTraderKey, expiryHeight) - if err != nil { - t.Fatalf("unable to watch account expiry: %v", err) - } - - // HandleAccountExpiry should not be invoked until after the expiration. - select { - case <-expirySignal: - t.Fatal("unexpected expiry signal") - case <-time.After(timeout): - } - - // Override the existing watch request with a new one that expires at - // double the height. - err = watcher.WatchAccountExpiration(testTraderKey, expiryHeight*2) - if err != nil { - t.Fatalf("unable to watch account expiry: %v", err) - } - - // HandleAccountExpiry should still not be invoked yet. - select { - case <-expirySignal: - t.Fatal("unexpected expiry signal") - case <-time.After(timeout): - } - - // Notify the first expiration height. This should not cause - // HandleAccountExpiry to be invoked as the second request overwrote it. - select { - case notifier.blockChan <- expiryHeight: - case <-time.After(timeout): - t.Fatal("unable to notify expiry") - } - - select { - case <-expirySignal: - t.Fatal("unexpected expiry signal") - case <-time.After(timeout): - } - - // Notify the new expiration height. This should cause - // HandleAccountExpiry to be invoked. - select { - case notifier.blockChan <- expiryHeight * 2: - case <-time.After(timeout): - t.Fatal("unable to notify expiry") - } - - select { - case <-expirySignal: - case <-time.After(timeout): - t.Fatal("expected expiry signal") - } -} - -// TestWatcherAccountAlreadyExpired ensures that the watcher performs its -// expected operations once an account expiration has already happened at the -// time of registration. -func TestWatcherAccountAlreadyExpired(t *testing.T) { - t.Parallel() - - const startHeight = 100 - - // Set up the required dependencies of the Watcher. - notifier := newMockChainNotifier() - - // The HandleAccountExpiry closure will use a signal to indicate that - // it's been invoked once an expiry notification is received. - expirySignal := make(chan struct{}) - handleExpiry := func(*btcec.PublicKey, uint32) error { - close(expirySignal) - return nil - } - - watcher := New(&Config{ - ChainNotifier: notifier, - HandleAccountExpiry: handleExpiry, - }) - if err := watcher.Start(); err != nil { - t.Fatalf("unable to start watcher: %v", err) - } - defer watcher.Stop() - - select { - case notifier.blockChan <- startHeight: - case <-time.After(timeout): - t.Fatal("unable to notify block") - } - - // Watch for an account's expiration that has already expired. - err := watcher.WatchAccountExpiration(testTraderKey, startHeight) - if err != nil { - t.Fatalf("unable to watch account expiry: %v", err) - } - - // HandleAccountExpiry should have been invoked since the expiration was - // already reached at the time of registration. - select { - case <-expirySignal: - case <-time.After(timeout): - t.Fatal("expected expiry signal") - } -} From abd603d765f8a518c9a9cdc6dc6fff8d5ceefdef Mon Sep 17 00:00:00 2001 From: positiveblue Date: Mon, 6 Dec 2021 22:00:48 -0800 Subject: [PATCH 2/6] account/watcher: add `ExpiryWatcher` and `EventHandler` interfaces Split watcher logic in three pices: - Controller: API + message dispatching - ExpiryWatcher: handle account expirations - EventHandler: implementation for each handler --- account/manager.go | 17 +++--- account/watcher/controller.go | 102 ++++++-------------------------- account/watcher/interfaces.go | 32 ++++++++++ account/watcher/watcher.go | 108 ++++++++++++++++++++++++++++++++++ 4 files changed, 165 insertions(+), 94 deletions(-) create mode 100644 account/watcher/watcher.go diff --git a/account/manager.go b/account/manager.go index a11efaba9..a304f7d67 100644 --- a/account/manager.go +++ b/account/manager.go @@ -153,11 +153,10 @@ func NewManager(cfg *ManagerConfig) *Manager { quit: make(chan struct{}), } - m.watcherCtrl = watcher.NewController(&watcher.Config{ - ChainNotifier: cfg.ChainNotifier, - HandleAccountConf: m.handleAccountConf, - HandleAccountSpend: m.handleAccountSpend, - HandleAccountExpiry: m.handleAccountExpiry, + m.watcherCtrl = watcher.NewController(&watcher.CtrlConfig{ + ChainNotifier: cfg.ChainNotifier, + // The manager implements the EventHandler interface + Handlers: m, }) return m @@ -866,9 +865,9 @@ func (m *Manager) handleStateOpen(ctx context.Context, account *Account) error { return nil } -// handleAccountConf takes the necessary steps after detecting the confirmation +// HandleAccountConf takes the necessary steps after detecting the confirmation // of an account on-chain. -func (m *Manager) handleAccountConf(traderKey *btcec.PublicKey, +func (m *Manager) HandleAccountConf(traderKey *btcec.PublicKey, confDetails *chainntnfs.TxConfirmation) error { account, err := m.cfg.Store.Account(traderKey) @@ -919,7 +918,7 @@ func (m *Manager) handleAccountConf(traderKey *btcec.PublicKey, // only track the spend of the latest batch, after it confirmed. So the account // output in the spend transaction should always match our database state if // it was a cooperative spend. -func (m *Manager) handleAccountSpend(traderKey *btcec.PublicKey, +func (m *Manager) HandleAccountSpend(traderKey *btcec.PublicKey, spendDetails *chainntnfs.SpendDetail) error { account, err := m.cfg.Store.Account(traderKey) @@ -1016,7 +1015,7 @@ func (m *Manager) handleAccountSpend(traderKey *btcec.PublicKey, } // handleAccountExpiry marks an account as expired within the database. -func (m *Manager) handleAccountExpiry(traderKey *btcec.PublicKey, +func (m *Manager) HandleAccountExpiry(traderKey *btcec.PublicKey, height uint32) error { account, err := m.cfg.Store.Account(traderKey) diff --git a/account/watcher/controller.go b/account/watcher/controller.go index 09f05a492..c9b37763d 100644 --- a/account/watcher/controller.go +++ b/account/watcher/controller.go @@ -26,25 +26,13 @@ type expiryReq struct { // Config contains all of the Controller's dependencies in order to carry out its // duties. -type Config struct { +type CtrlConfig struct { // ChainNotifier is responsible for requesting confirmation and spend // notifications for accounts. ChainNotifier lndclient.ChainNotifierClient - // HandleAccountConf abstracts the operations that should be performed - // for an account once we detect its confirmation. The account is - // identified by its user sub key (i.e., trader key). - HandleAccountConf func(*btcec.PublicKey, *chainntnfs.TxConfirmation) error - - // HandleAccountSpend abstracts the operations that should be performed - // for an account once we detect its spend. The account is identified by - // its user sub key (i.e., trader key). - HandleAccountSpend func(*btcec.PublicKey, *chainntnfs.SpendDetail) error - - // HandleAccountExpiry the operations that should be perform for an - // account once it's expired. The account is identified by its user sub - // key (i.e., trader key). - HandleAccountExpiry func(*btcec.PublicKey, uint32) error + // Handlers define the handler to be used after receiving every event. + Handlers EventHandler } // controller implements the Controller interface @@ -52,7 +40,9 @@ type controller struct { started sync.Once stopped sync.Once - cfg Config + cfg *CtrlConfig + + watcher ExpiryWatcher expiryReqs chan *expiryReq @@ -70,9 +60,11 @@ var _ Controller = (*controller)(nil) // NewController returns an internal struct type that implements the // Controller interface. -func NewController(cfg *Config) *controller { // nolint:golint +func NewController(cfg *CtrlConfig) *controller { // nolint:golint + watcher := NewExpiryWatcher(cfg.Handlers) return &controller{ - cfg: *cfg, + cfg: cfg, + watcher: watcher, expiryReqs: make(chan *expiryReq), quit: make(chan struct{}), spendCancels: make(map[[33]byte]func()), @@ -134,24 +126,11 @@ func (c *controller) Stop() { func (c *controller) expiryHandler(blockChan chan int32, errChan chan error) { defer c.wg.Done() - var ( - // bestHeight is the height we believe the current chain is at. - bestHeight uint32 - - // expirations keeps track of the current accounts we're - // watching expirations for. - expirations = make(map[[33]byte]uint32) - - // expirationsPerHeight keeps track of all registered accounts - // that expire at a certain height. - expirationsPerHeight = make(map[uint32][]*btcec.PublicKey) - ) - // Wait for the initial block notification to be received before we // begin handling requests. select { case newBlock := <-blockChan: - bestHeight = uint32(newBlock) + c.watcher.NewBlock(uint32(newBlock)) case err := <-errChan: log.Errorf("Unable to receive initial block notification: %v", err) @@ -164,34 +143,7 @@ func (c *controller) expiryHandler(blockChan chan int32, errChan chan error) { // A new block notification has arrived, update our known // height and notify any newly expired accounts. case newBlock := <-blockChan: - bestHeight = uint32(newBlock) - - for _, traderKey := range expirationsPerHeight[bestHeight] { - var accountKey [33]byte - copy(accountKey[:], traderKey.SerializeCompressed()) - - // If the account doesn't exist within the - // expiration set, then the request was - // canceled and there's nothing for us to do. - // Similarly, if the request was updated to - // track a new height, then we can skip it. - curExpiry, ok := expirations[accountKey] - if !ok || bestHeight != curExpiry { - continue - } - - err := c.cfg.HandleAccountExpiry( - traderKey, bestHeight, - ) - if err != nil { - log.Errorf("Unable to handle "+ - "expiration of account %x: %v", - traderKey.SerializeCompressed(), - err) - } - } - - delete(expirationsPerHeight, bestHeight) + c.watcher.NewBlock(uint32(newBlock)) // An error occurred while being sent a block notification. case err := <-errChan: @@ -200,30 +152,9 @@ func (c *controller) expiryHandler(blockChan chan int32, errChan chan error) { // A new watch expiry request has been received for an account. case req := <-c.expiryReqs: - var accountKey [33]byte - copy(accountKey[:], req.traderKey.SerializeCompressed()) - - // If it's already expired, we don't need to track it. - if req.expiry <= bestHeight { - err := c.cfg.HandleAccountExpiry( - req.traderKey, bestHeight, - ) - if err != nil { - log.Errorf("Unable to handle "+ - "expiration of account %x: %v", - req.traderKey.SerializeCompressed(), - err) - } - delete(expirations, accountKey) - - continue - } - - expirations[accountKey] = req.expiry - expirationsPerHeight[req.expiry] = append( - expirationsPerHeight[req.expiry], req.traderKey, + c.watcher.AddAccountExpiration( + req.traderKey, req.expiry, ) - case <-c.quit: return } @@ -284,7 +215,8 @@ func (c *controller) waitForAccountConf(traderKey *btcec.PublicKey, select { case conf := <-confChan: - if err := c.cfg.HandleAccountConf(traderKey, conf); err != nil { + err := c.cfg.Handlers.HandleAccountConf(traderKey, conf) + if err != nil { log.Errorf("Unable to handle confirmation for account "+ "%x: %v", traderKey.SerializeCompressed(), err) } @@ -362,7 +294,7 @@ func (c *controller) waitForAccountSpend(traderKey *btcec.PublicKey, select { case spend := <-spendChan: - err := c.cfg.HandleAccountSpend(traderKey, spend) + err := c.cfg.Handlers.HandleAccountSpend(traderKey, spend) if err != nil { log.Errorf("Unable to handle spend for account %x: %v", traderKey.SerializeCompressed(), err) diff --git a/account/watcher/interfaces.go b/account/watcher/interfaces.go index c2cd4f7f3..01f3afec1 100644 --- a/account/watcher/interfaces.go +++ b/account/watcher/interfaces.go @@ -4,6 +4,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainntnfs" ) // Controller is the interface used by other components to communicate with the @@ -44,3 +45,34 @@ type Controller interface { // watch requests and the new expiration will be tracked instead. WatchAccountExpiration(traderKey *btcec.PublicKey, expiry uint32) error } + +// EventHandler is the interface used by other components to handle the different +// watcher events. +type EventHandler interface { + // HandleAccountConf abstracts the operations that should be performed + // for an account once we detect its confirmation. The account is + // identified by its user sub key (i.e., trader key). + HandleAccountConf(*btcec.PublicKey, *chainntnfs.TxConfirmation) error + + // HandleAccountSpend abstracts the operations that should be performed + // for an account once we detect its spend. The account is identified by + // its user sub key (i.e., trader key). + HandleAccountSpend(*btcec.PublicKey, *chainntnfs.SpendDetail) error + + // HandleAccountExpiry the operations that should be perform for an + // account once it's expired. The account is identified by its user sub + // key (i.e., trader key). + HandleAccountExpiry(*btcec.PublicKey, uint32) error +} + +// ExpiryWatcher is the interface for the component in charge of the accounts' +// expiration. +type ExpiryWatcher interface { + // NewBlock updates the current bestHeight and handles overdue + // expirations. + NewBlock(bestHeight uint32) + + // AddAccountExpiration creates or updates the existing record for the + // traderKey. + AddAccountExpiration(traderKey *btcec.PublicKey, expiry uint32) +} diff --git a/account/watcher/watcher.go b/account/watcher/watcher.go new file mode 100644 index 000000000..41df56b8b --- /dev/null +++ b/account/watcher/watcher.go @@ -0,0 +1,108 @@ +package watcher + +import ( + "sync" + + "github.com/btcsuite/btcd/btcec" +) + +// expiryWatcher implements the ExpiryWatcher interface +type expiryWatcher struct { + handlers EventHandler + + // bestHeight is the height we believe the current chain is at. + bestHeight uint32 + + // expirations keeps track of the current accounts we're + // watching expirations for. + expirations map[[33]byte]uint32 + + // expirationsPerHeight keeps track of all registered accounts + // that expire at a certain height. + expirationsPerHeight map[uint32][]*btcec.PublicKey + + expirationsMtx sync.Mutex +} + +// NewExpiryWatcher instantiates a new ExpiryWatcher. +func NewExpiryWatcher(handlers EventHandler) *expiryWatcher { // nolint:golint + return &expiryWatcher{ + handlers: handlers, + expirations: make(map[[33]byte]uint32), + expirationsPerHeight: make(map[uint32][]*btcec.PublicKey), + } +} + +// NewBlock updates the current bestHeight. +func (w *expiryWatcher) NewBlock(bestHeight uint32) { + w.expirationsMtx.Lock() + defer w.expirationsMtx.Unlock() + + w.bestHeight = bestHeight + w.overdueExpirations(w.bestHeight) +} + +// overdueExpirations handles the expirations for the given block. +func (w *expiryWatcher) overdueExpirations(blockHeight uint32) { + for _, traderKey := range w.expirationsPerHeight[blockHeight] { + var accountKey [33]byte + copy(accountKey[:], traderKey.SerializeCompressed()) + + // If the account doesn't exist within the + // expiration set, then the request was + // canceled and there's nothing for us to do. + // Similarly, if the request was updated to + // track a new height, then we can skip it. + curExpiry, ok := w.expirations[accountKey] + if !ok || blockHeight != curExpiry { + continue + } + + err := w.handlers.HandleAccountExpiry( + traderKey, blockHeight, + ) + if err != nil { + log.Errorf("Unable to handle "+ + "expiration of account %x: %v", + traderKey.SerializeCompressed(), + err) + } + } + + delete(w.expirationsPerHeight, blockHeight) +} + +// AddAccountExpiration creates or updates the existing record for the traderKey. +func (w *expiryWatcher) AddAccountExpiration(traderKey *btcec.PublicKey, + expiry uint32) { + + w.expirationsMtx.Lock() + defer w.expirationsMtx.Unlock() + + var accountKey [33]byte + copy(accountKey[:], traderKey.SerializeCompressed()) + + // If it's already expired, we don't need to track it. + if expiry <= w.bestHeight { + // Delete the entry from the watcher.expirations + // and handle the expiry in the background. + go func() { + if err := w.handlers.HandleAccountExpiry( + traderKey, w.bestHeight, + ); err != nil { + log.Errorf("Unable to handle "+ + "expiration of account %x: %v", + traderKey.SerializeCompressed(), + err) + } + }() + + delete(w.expirations, accountKey) + return + } + + w.expirations[accountKey] = expiry + w.expirationsPerHeight[expiry] = append( + w.expirationsPerHeight[expiry], traderKey, + ) +} From 7b537919970c3edaaff59f71d4f3a224e6d1f385 Mon Sep 17 00:00:00 2001 From: positiveblue Date: Mon, 6 Dec 2021 22:01:42 -0800 Subject: [PATCH 3/6] account/watcher: delete `expiryReq` --- account/manager.go | 10 ++-------- account/watcher/controller.go | 34 +++------------------------------- account/watcher/interfaces.go | 2 +- 3 files changed, 6 insertions(+), 40 deletions(-) diff --git a/account/manager.go b/account/manager.go index a304f7d67..bfd12565f 100644 --- a/account/manager.go +++ b/account/manager.go @@ -843,12 +843,9 @@ func (m *Manager) handleStateOpen(ctx context.Context, account *Account) error { return fmt.Errorf("unable to watch for spend: %v", err) } - err = m.watcherCtrl.WatchAccountExpiration( + m.watcherCtrl.WatchAccountExpiration( account.TraderKey.PubKey, account.Expiry, ) - if err != nil { - return fmt.Errorf("unable to watch for expiration: %v", err) - } // Now that we have an open account, subscribe for updates to it to the // server. We subscribe for the account instead of the individual orders @@ -1252,10 +1249,7 @@ func (m *Manager) RenewAccount(ctx context.Context, // Begin to track the new account expiration, which will overwrite the // existing expiration request. - err = m.watcherCtrl.WatchAccountExpiration(traderKey, modifiedAccount.Expiry) - if err != nil { - return nil, nil, err - } + m.watcherCtrl.WatchAccountExpiration(traderKey, modifiedAccount.Expiry) return modifiedAccount, spendPkg.tx, nil } diff --git a/account/watcher/controller.go b/account/watcher/controller.go index c9b37763d..19b110fc4 100644 --- a/account/watcher/controller.go +++ b/account/watcher/controller.go @@ -2,7 +2,6 @@ package watcher import ( "context" - "errors" "sync" "github.com/btcsuite/btcd/btcec" @@ -14,17 +13,7 @@ import ( "google.golang.org/grpc/status" ) -// expiryReq is an internal message we'll sumbit to the Watcher to process for -// external expiration requests. -type expiryReq struct { - // traderKey is the base trader key of the account. - traderKey *btcec.PublicKey - - // expiry is the expiry of the account as a block height. - expiry uint32 -} - -// Config contains all of the Controller's dependencies in order to carry out its +// CtrlConfig contains all of the Controller's dependencies in order to carry out its // duties. type CtrlConfig struct { // ChainNotifier is responsible for requesting confirmation and spend @@ -44,8 +33,6 @@ type controller struct { watcher ExpiryWatcher - expiryReqs chan *expiryReq - wg sync.WaitGroup quit chan struct{} ctxCancels []func() @@ -65,7 +52,6 @@ func NewController(cfg *CtrlConfig) *controller { // nolint:golint return &controller{ cfg: cfg, watcher: watcher, - expiryReqs: make(chan *expiryReq), quit: make(chan struct{}), spendCancels: make(map[[33]byte]func()), confCancels: make(map[[33]byte]func()), @@ -150,11 +136,6 @@ func (c *controller) expiryHandler(blockChan chan int32, errChan chan error) { log.Errorf("Unable to receive block notification: %v", err) - // A new watch expiry request has been received for an account. - case req := <-c.expiryReqs: - c.watcher.AddAccountExpiration( - req.traderKey, req.expiry, - ) case <-c.quit: return } @@ -322,18 +303,9 @@ func (c *controller) waitForAccountSpend(traderKey *btcec.PublicKey, // Successive calls for the same account will cancel any previous expiration // watch requests and the new expiration will be tracked instead. func (c *controller) WatchAccountExpiration(traderKey *btcec.PublicKey, - expiry uint32) error { + expiry uint32) { - select { - case c.expiryReqs <- &expiryReq{ - traderKey: traderKey, - expiry: expiry, - }: - return nil - - case <-c.quit: - return errors.New("watcher shutting down") - } + c.watcher.AddAccountExpiration(traderKey, expiry) } // CancelAccountSpend cancels the spend watcher of the given account, if one is diff --git a/account/watcher/interfaces.go b/account/watcher/interfaces.go index 01f3afec1..c4b4095f9 100644 --- a/account/watcher/interfaces.go +++ b/account/watcher/interfaces.go @@ -43,7 +43,7 @@ type Controller interface { // WatchAccountExpiration watches for the expiration of an account on-chain. // Successive calls for the same account will cancel any previous expiration // watch requests and the new expiration will be tracked instead. - WatchAccountExpiration(traderKey *btcec.PublicKey, expiry uint32) error + WatchAccountExpiration(traderKey *btcec.PublicKey, expiry uint32) } // EventHandler is the interface used by other components to handle the different From 739498543aad99a3742271e073fd3859a7837602 Mon Sep 17 00:00:00 2001 From: positiveblue Date: Wed, 1 Dec 2021 21:16:38 -0800 Subject: [PATCH 4/6] mocks: add account/watcher mocks --- account/watcher/mock_interface_test.go | 240 +++++++++++++++++++++++++ gen.go | 1 + internal/test/interfaces.go | 4 + internal/test/mock_interfaces.go | 73 ++++++++ 4 files changed, 318 insertions(+) create mode 100644 account/watcher/mock_interface_test.go diff --git a/account/watcher/mock_interface_test.go b/account/watcher/mock_interface_test.go new file mode 100644 index 000000000..3e5376599 --- /dev/null +++ b/account/watcher/mock_interface_test.go @@ -0,0 +1,240 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: account/watcher/interfaces.go + +// Package watcher is a generated GoMock package. +package watcher + +import ( + reflect "reflect" + + btcec "github.com/btcsuite/btcd/btcec" + chainhash "github.com/btcsuite/btcd/chaincfg/chainhash" + wire "github.com/btcsuite/btcd/wire" + gomock "github.com/golang/mock/gomock" + chainntnfs "github.com/lightningnetwork/lnd/chainntnfs" +) + +// MockController is a mock of Controller interface. +type MockController struct { + ctrl *gomock.Controller + recorder *MockControllerMockRecorder +} + +// MockControllerMockRecorder is the mock recorder for MockController. +type MockControllerMockRecorder struct { + mock *MockController +} + +// NewMockController creates a new mock instance. +func NewMockController(ctrl *gomock.Controller) *MockController { + mock := &MockController{ctrl: ctrl} + mock.recorder = &MockControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockController) EXPECT() *MockControllerMockRecorder { + return m.recorder +} + +// CancelAccountConf mocks base method. +func (m *MockController) CancelAccountConf(traderKey *btcec.PublicKey) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CancelAccountConf", traderKey) +} + +// CancelAccountConf indicates an expected call of CancelAccountConf. +func (mr *MockControllerMockRecorder) CancelAccountConf(traderKey interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelAccountConf", reflect.TypeOf((*MockController)(nil).CancelAccountConf), traderKey) +} + +// CancelAccountSpend mocks base method. +func (m *MockController) CancelAccountSpend(traderKey *btcec.PublicKey) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CancelAccountSpend", traderKey) +} + +// CancelAccountSpend indicates an expected call of CancelAccountSpend. +func (mr *MockControllerMockRecorder) CancelAccountSpend(traderKey interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelAccountSpend", reflect.TypeOf((*MockController)(nil).CancelAccountSpend), traderKey) +} + +// Start mocks base method. +func (m *MockController) Start() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Start") + ret0, _ := ret[0].(error) + return ret0 +} + +// Start indicates an expected call of Start. +func (mr *MockControllerMockRecorder) Start() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockController)(nil).Start)) +} + +// Stop mocks base method. +func (m *MockController) Stop() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Stop") +} + +// Stop indicates an expected call of Stop. +func (mr *MockControllerMockRecorder) Stop() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockController)(nil).Stop)) +} + +// WatchAccountConf mocks base method. +func (m *MockController) WatchAccountConf(traderKey *btcec.PublicKey, txHash chainhash.Hash, script []byte, numConfs, heightHint uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WatchAccountConf", traderKey, txHash, script, numConfs, heightHint) + ret0, _ := ret[0].(error) + return ret0 +} + +// WatchAccountConf indicates an expected call of WatchAccountConf. +func (mr *MockControllerMockRecorder) WatchAccountConf(traderKey, txHash, script, numConfs, heightHint interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WatchAccountConf", reflect.TypeOf((*MockController)(nil).WatchAccountConf), traderKey, txHash, script, numConfs, heightHint) +} + +// WatchAccountExpiration mocks base method. +func (m *MockController) WatchAccountExpiration(traderKey *btcec.PublicKey, expiry uint32) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "WatchAccountExpiration", traderKey, expiry) +} + +// WatchAccountExpiration indicates an expected call of WatchAccountExpiration. +func (mr *MockControllerMockRecorder) WatchAccountExpiration(traderKey, expiry interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WatchAccountExpiration", reflect.TypeOf((*MockController)(nil).WatchAccountExpiration), traderKey, expiry) +} + +// WatchAccountSpend mocks base method. +func (m *MockController) WatchAccountSpend(traderKey *btcec.PublicKey, accountPoint wire.OutPoint, script []byte, heightHint uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WatchAccountSpend", traderKey, accountPoint, script, heightHint) + ret0, _ := ret[0].(error) + return ret0 +} + +// WatchAccountSpend indicates an expected call of WatchAccountSpend. +func (mr *MockControllerMockRecorder) WatchAccountSpend(traderKey, accountPoint, script, heightHint interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WatchAccountSpend", reflect.TypeOf((*MockController)(nil).WatchAccountSpend), traderKey, accountPoint, script, heightHint) +} + +// MockEventHandler is a mock of EventHandler interface. +type MockEventHandler struct { + ctrl *gomock.Controller + recorder *MockEventHandlerMockRecorder +} + +// MockEventHandlerMockRecorder is the mock recorder for MockEventHandler. +type MockEventHandlerMockRecorder struct { + mock *MockEventHandler +} + +// NewMockEventHandler creates a new mock instance. +func NewMockEventHandler(ctrl *gomock.Controller) *MockEventHandler { + mock := &MockEventHandler{ctrl: ctrl} + mock.recorder = &MockEventHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEventHandler) EXPECT() *MockEventHandlerMockRecorder { + return m.recorder +} + +// HandleAccountConf mocks base method. +func (m *MockEventHandler) HandleAccountConf(arg0 *btcec.PublicKey, arg1 *chainntnfs.TxConfirmation) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleAccountConf", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// HandleAccountConf indicates an expected call of HandleAccountConf. +func (mr *MockEventHandlerMockRecorder) HandleAccountConf(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleAccountConf", reflect.TypeOf((*MockEventHandler)(nil).HandleAccountConf), arg0, arg1) +} + +// HandleAccountExpiry mocks base method. +func (m *MockEventHandler) HandleAccountExpiry(arg0 *btcec.PublicKey, arg1 uint32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleAccountExpiry", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// HandleAccountExpiry indicates an expected call of HandleAccountExpiry. +func (mr *MockEventHandlerMockRecorder) HandleAccountExpiry(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleAccountExpiry", reflect.TypeOf((*MockEventHandler)(nil).HandleAccountExpiry), arg0, arg1) +} + +// HandleAccountSpend mocks base method. +func (m *MockEventHandler) HandleAccountSpend(arg0 *btcec.PublicKey, arg1 *chainntnfs.SpendDetail) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleAccountSpend", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// HandleAccountSpend indicates an expected call of HandleAccountSpend. +func (mr *MockEventHandlerMockRecorder) HandleAccountSpend(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleAccountSpend", reflect.TypeOf((*MockEventHandler)(nil).HandleAccountSpend), arg0, arg1) +} + +// MockExpiryWatcher is a mock of ExpiryWatcher interface. +type MockExpiryWatcher struct { + ctrl *gomock.Controller + recorder *MockExpiryWatcherMockRecorder +} + +// MockExpiryWatcherMockRecorder is the mock recorder for MockExpiryWatcher. +type MockExpiryWatcherMockRecorder struct { + mock *MockExpiryWatcher +} + +// NewMockExpiryWatcher creates a new mock instance. +func NewMockExpiryWatcher(ctrl *gomock.Controller) *MockExpiryWatcher { + mock := &MockExpiryWatcher{ctrl: ctrl} + mock.recorder = &MockExpiryWatcherMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockExpiryWatcher) EXPECT() *MockExpiryWatcherMockRecorder { + return m.recorder +} + +// AddAccountExpiration mocks base method. +func (m *MockExpiryWatcher) AddAccountExpiration(traderKey *btcec.PublicKey, expiry uint32) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddAccountExpiration", traderKey, expiry) +} + +// AddAccountExpiration indicates an expected call of AddAccountExpiration. +func (mr *MockExpiryWatcherMockRecorder) AddAccountExpiration(traderKey, expiry interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddAccountExpiration", reflect.TypeOf((*MockExpiryWatcher)(nil).AddAccountExpiration), traderKey, expiry) +} + +// NewBlock mocks base method. +func (m *MockExpiryWatcher) NewBlock(bestHeight uint32) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NewBlock", bestHeight) +} + +// NewBlock indicates an expected call of NewBlock. +func (mr *MockExpiryWatcherMockRecorder) NewBlock(bestHeight interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewBlock", reflect.TypeOf((*MockExpiryWatcher)(nil).NewBlock), bestHeight) +} diff --git a/gen.go b/gen.go index fa42ee4f0..2f72aa6e5 100644 --- a/gen.go +++ b/gen.go @@ -7,3 +7,4 @@ package pool //go:generate mockgen -source=sidecar/interfaces.go -package=sidecar -destination=sidecar/mock_interfaces.go //go:generate mockgen -source=internal/test/interfaces.go -package=test -destination=internal/test/mock_interfaces.go +//go:generate mockgen -source=account/watcher/interfaces.go -package=watcher -destination=account/watcher/mock_interface_test.go diff --git a/internal/test/interfaces.go b/internal/test/interfaces.go index 30c5b3633..1c773474e 100644 --- a/internal/test/interfaces.go +++ b/internal/test/interfaces.go @@ -9,3 +9,7 @@ type SignerClient interface { type WalletKitClient interface { lndclient.WalletKitClient } + +type ChainNotifierClient interface { + lndclient.ChainNotifierClient +} diff --git a/internal/test/mock_interfaces.go b/internal/test/mock_interfaces.go index 244b824bd..482f6a6b3 100644 --- a/internal/test/mock_interfaces.go +++ b/internal/test/mock_interfaces.go @@ -10,11 +10,13 @@ import ( time "time" btcec "github.com/btcsuite/btcd/btcec" + chainhash "github.com/btcsuite/btcd/chaincfg/chainhash" wire "github.com/btcsuite/btcd/wire" btcutil "github.com/btcsuite/btcutil" wtxmgr "github.com/btcsuite/btcwallet/wtxmgr" gomock "github.com/golang/mock/gomock" lndclient "github.com/lightninglabs/lndclient" + chainntnfs "github.com/lightningnetwork/lnd/chainntnfs" input "github.com/lightningnetwork/lnd/input" keychain "github.com/lightningnetwork/lnd/keychain" walletrpc "github.com/lightningnetwork/lnd/lnrpc/walletrpc" @@ -319,3 +321,74 @@ func (mr *MockWalletKitClientMockRecorder) SendOutputs(ctx, outputs, feeRate, la mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendOutputs", reflect.TypeOf((*MockWalletKitClient)(nil).SendOutputs), ctx, outputs, feeRate, label) } + +// MockChainNotifierClient is a mock of ChainNotifierClient interface. +type MockChainNotifierClient struct { + ctrl *gomock.Controller + recorder *MockChainNotifierClientMockRecorder +} + +// MockChainNotifierClientMockRecorder is the mock recorder for MockChainNotifierClient. +type MockChainNotifierClientMockRecorder struct { + mock *MockChainNotifierClient +} + +// NewMockChainNotifierClient creates a new mock instance. +func NewMockChainNotifierClient(ctrl *gomock.Controller) *MockChainNotifierClient { + mock := &MockChainNotifierClient{ctrl: ctrl} + mock.recorder = &MockChainNotifierClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockChainNotifierClient) EXPECT() *MockChainNotifierClientMockRecorder { + return m.recorder +} + +// RegisterBlockEpochNtfn mocks base method. +func (m *MockChainNotifierClient) RegisterBlockEpochNtfn(ctx context.Context) (chan int32, chan error, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterBlockEpochNtfn", ctx) + ret0, _ := ret[0].(chan int32) + ret1, _ := ret[1].(chan error) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// RegisterBlockEpochNtfn indicates an expected call of RegisterBlockEpochNtfn. +func (mr *MockChainNotifierClientMockRecorder) RegisterBlockEpochNtfn(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterBlockEpochNtfn", reflect.TypeOf((*MockChainNotifierClient)(nil).RegisterBlockEpochNtfn), ctx) +} + +// RegisterConfirmationsNtfn mocks base method. +func (m *MockChainNotifierClient) RegisterConfirmationsNtfn(ctx context.Context, txid *chainhash.Hash, pkScript []byte, numConfs, heightHint int32) (chan *chainntnfs.TxConfirmation, chan error, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterConfirmationsNtfn", ctx, txid, pkScript, numConfs, heightHint) + ret0, _ := ret[0].(chan *chainntnfs.TxConfirmation) + ret1, _ := ret[1].(chan error) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// RegisterConfirmationsNtfn indicates an expected call of RegisterConfirmationsNtfn. +func (mr *MockChainNotifierClientMockRecorder) RegisterConfirmationsNtfn(ctx, txid, pkScript, numConfs, heightHint interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterConfirmationsNtfn", reflect.TypeOf((*MockChainNotifierClient)(nil).RegisterConfirmationsNtfn), ctx, txid, pkScript, numConfs, heightHint) +} + +// RegisterSpendNtfn mocks base method. +func (m *MockChainNotifierClient) RegisterSpendNtfn(ctx context.Context, outpoint *wire.OutPoint, pkScript []byte, heightHint int32) (chan *chainntnfs.SpendDetail, chan error, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterSpendNtfn", ctx, outpoint, pkScript, heightHint) + ret0, _ := ret[0].(chan *chainntnfs.SpendDetail) + ret1, _ := ret[1].(chan error) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// RegisterSpendNtfn indicates an expected call of RegisterSpendNtfn. +func (mr *MockChainNotifierClientMockRecorder) RegisterSpendNtfn(ctx, outpoint, pkScript, heightHint interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterSpendNtfn", reflect.TypeOf((*MockChainNotifierClient)(nil).RegisterSpendNtfn), ctx, outpoint, pkScript, heightHint) +} From 17e77c380f230c1757738e7f9b2b10c748baa83c Mon Sep 17 00:00:00 2001 From: positiveblue Date: Wed, 1 Dec 2021 21:56:32 -0800 Subject: [PATCH 5/6] account/watcher: add controller tests --- account/watcher/controller_test.go | 346 +++++++++++++++++++++++++++++ 1 file changed, 346 insertions(+) create mode 100644 account/watcher/controller_test.go diff --git a/account/watcher/controller_test.go b/account/watcher/controller_test.go new file mode 100644 index 000000000..4008e7ab6 --- /dev/null +++ b/account/watcher/controller_test.go @@ -0,0 +1,346 @@ +package watcher + +import ( + "crypto/rand" + "encoding/hex" + "errors" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + gomock "github.com/golang/mock/gomock" + "github.com/lightninglabs/pool/internal/test" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + errCtrlExpected = errors.New("random error") +) + +var controllerLifeCycleTestCases = []struct { + name string + mockSetter func(mockChainNotifier *test.MockChainNotifierClient) + expectedError string +}{{ + name: "we are able to start and stop the watcher " + + "successfully", + mockSetter: func(mockChainNotifier *test.MockChainNotifierClient) { + blockChan := make(chan int32) + errChan := make(chan error) + mockChainNotifier.EXPECT(). + RegisterBlockEpochNtfn(gomock.Any()). + Return(blockChan, errChan, nil) + }, + expectedError: "", +}, { + name: "unable to start watcher because of " + + "RegisterBlockEpochNtfn register error", + mockSetter: func(mockChainNotifier *test.MockChainNotifierClient) { + blockChan := make(chan int32) + errChan := make(chan error) + mockChainNotifier.EXPECT(). + RegisterBlockEpochNtfn(gomock.Any()). + Return( + blockChan, + errChan, + errCtrlExpected, + ) + }, + expectedError: errCtrlExpected.Error(), +}} + +func TestWatcherControllerLifeCycle(t *testing.T) { + for _, tc := range controllerLifeCycleTestCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + chainNotifier := test.NewMockChainNotifierClient( + mockCtrl, + ) + + tc.mockSetter(chainNotifier) + + cfg := &CtrlConfig{ + ChainNotifier: chainNotifier, + } + + watcherController := NewController(cfg) + + err := watcherController.Start() + if tc.expectedError != "" { + assert.EqualError(t, err, tc.expectedError) + return + } + require.NoError(t, err) + + watcherController.Stop() + + select { + case <-watcherController.quit: + return + case <-time.After(2 * time.Second): + t.Error("watcher controller not closed on time") + } + }) + } +} + +var controllerNewBlocksTestCases = []struct { + name string + blocks []int32 +}{{ + name: "every time that we receive a new block we update" + + "our bestHeight and look for overdue expirations", + blocks: []int32{1, 2, 3}, +}} + +func TestWatcherControllerNewBlocks(t *testing.T) { + for _, tc := range controllerNewBlocksTestCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + blockChan := make(chan int32) + errChan := make(chan error) + chainNotifier := test.NewMockChainNotifierClient( + mockCtrl, + ) + + chainNotifier.EXPECT(). + RegisterBlockEpochNtfn(gomock.Any()). + Return(blockChan, errChan, nil) + + watcher := NewMockExpiryWatcher(mockCtrl) + + for _, block := range tc.blocks { + watcher.EXPECT(). + NewBlock(uint32(block)) + } + + cfg := &CtrlConfig{ + ChainNotifier: chainNotifier, + } + + watcherController := NewController(cfg) + watcherController.watcher = watcher + + err := watcherController.Start() + require.NoError(t, err) + + for _, block := range tc.blocks { + blockChan <- block + } + + watcherController.Stop() + + select { + case <-watcherController.quit: + return + case <-time.After(2 * time.Second): + t.Error("new blocks not processed on time") + } + }) + } +} + +var controllerWatchAccountTestCases = []struct { + name string + expectedErr string +}{{ + name: "Watch account happy path", + // TODO (positiveblue): add tests for `cancel` logic +}} + +func TestWatcherControllerWatchAccount(t *testing.T) { + traderKeyStr := "036b51e0cc2d9e5988ee4967e0ba67ef3727bb633fea21a0af58e0c9395446ba09" + traderKeyRaw, _ := hex.DecodeString(traderKeyStr) + traderKey, _ := btcec.ParsePubKey(traderKeyRaw, btcec.S256()) + + var txHash chainhash.Hash + if _, err := rand.Read(txHash[:]); err != nil { // nolint:gosec + t.Error("unable to create random hash") + } + + script := make([]byte, 64) + if _, err := rand.Read(script); err != nil { // nolint:gosec + t.Error("unable to create random hash") + } + + numConfs := uint32(6) + heightHint := uint32(8) + + for _, tc := range controllerWatchAccountTestCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + confChan := make(chan *chainntnfs.TxConfirmation) + errChan := make(chan error) + + doneChan := make(chan struct{}) + + chainNotifier := test.NewMockChainNotifierClient( + mockCtrl, + ) + + chainNotifier.EXPECT(). + RegisterBlockEpochNtfn(gomock.Any()) + + chainNotifier.EXPECT(). + RegisterConfirmationsNtfn( + gomock.Any(), &txHash, script, + int32(numConfs), int32(heightHint), + ). + Return(confChan, errChan, nil) + + confirmation := &chainntnfs.TxConfirmation{} + + eventHanlers := NewMockEventHandler(mockCtrl) + + eventHanlers.EXPECT(). + HandleAccountConf(traderKey, confirmation). + Return(nil). + Do(func(_ *btcec.PublicKey, + _ *chainntnfs.TxConfirmation) { + + // Close the channel so we signal the + // test that this function was executed + close(doneChan) + }) + + cfg := &CtrlConfig{ + ChainNotifier: chainNotifier, + Handlers: eventHanlers, + } + + watcherController := NewController(cfg) + + err := watcherController.Start() + require.NoError(t, err) + + err = watcherController.WatchAccountConf( + traderKey, txHash, script, numConfs, heightHint, + ) + require.NoError(t, err) + + confChan <- confirmation + + select { + case <-doneChan: + return + case <-time.After(2 * time.Second): + t.Error("confirmation not processed on time") + } + }) + + } +} + +var controllerWatchAccountSpendTestCases = []struct { + name string + expectedErr string +}{{ + name: "Watch account spend happy path", + // TODO (positiveblue): add tests for `cancel` logic +}} + +func TestWatcherControllerWatchAccountSpend(t *testing.T) { + traderKeyStr := "036b51e0cc2d9e5988ee4967e0ba67ef3727bb633fea21a0af58e0c9395446ba09" + traderKeyRaw, _ := hex.DecodeString(traderKeyStr) + traderKey, _ := btcec.ParsePubKey(traderKeyRaw, btcec.S256()) + + outpoint := wire.OutPoint{} + + script := make([]byte, 64) + if _, err := rand.Read(script); err != nil { // nolint:gosec + t.Error("unable to create random hash") + } + + heightHint := uint32(8) + + for _, tc := range controllerWatchAccountSpendTestCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + spendChan := make(chan *chainntnfs.SpendDetail) + errChan := make(chan error) + + doneChan := make(chan struct{}) + + chainNotifier := test.NewMockChainNotifierClient( + mockCtrl, + ) + + chainNotifier.EXPECT(). + RegisterBlockEpochNtfn(gomock.Any()) + + chainNotifier.EXPECT(). + RegisterSpendNtfn( + gomock.Any(), &outpoint, script, + int32(heightHint), + ). + Return(spendChan, errChan, nil) + + handlers := NewMockEventHandler(mockCtrl) + + spendDetails := &chainntnfs.SpendDetail{} + + handlers.EXPECT(). + HandleAccountSpend(traderKey, spendDetails). + Return(nil). + Do(func(_ *btcec.PublicKey, + _ *chainntnfs.SpendDetail) { + + // Close the channel so we signal the + // test that this function was executed + close(doneChan) + }) + + cfg := &CtrlConfig{ + ChainNotifier: chainNotifier, + Handlers: handlers, + } + + watcherController := NewController(cfg) + + err := watcherController.Start() + require.NoError(t, err) + + err = watcherController.WatchAccountSpend( + traderKey, outpoint, script, heightHint, + ) + require.NoError(t, err) + + spendChan <- spendDetails + + select { + case <-doneChan: + return + case <-time.After(2 * time.Second): + t.Error("spend not processed on time") + } + }) + } +} From 1856f798fc68555cc9619606fa7de02cfe5e2854 Mon Sep 17 00:00:00 2001 From: positiveblue Date: Wed, 1 Dec 2021 21:32:03 -0800 Subject: [PATCH 6/6] account/watcher: add ExpiryWatcher tests --- account/watcher/watcher.go | 1 + account/watcher/watcher_test.go | 248 ++++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+) create mode 100644 account/watcher/watcher_test.go diff --git a/account/watcher/watcher.go b/account/watcher/watcher.go index 41df56b8b..9f5d902cd 100644 --- a/account/watcher/watcher.go +++ b/account/watcher/watcher.go @@ -67,6 +67,7 @@ func (w *expiryWatcher) overdueExpirations(blockHeight uint32) { traderKey.SerializeCompressed(), err) } + delete(w.expirations, accountKey) } delete(w.expirationsPerHeight, blockHeight) diff --git a/account/watcher/watcher_test.go b/account/watcher/watcher_test.go new file mode 100644 index 000000000..d457be9e7 --- /dev/null +++ b/account/watcher/watcher_test.go @@ -0,0 +1,248 @@ +package watcher + +import ( + "crypto/ecdsa" + "errors" + "math/rand" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + gomock "github.com/golang/mock/gomock" +) + +func randomPrivateKey(seed int64) *btcec.PrivateKey { + r := rand.New(rand.NewSource(seed)) + key, err := ecdsa.GenerateKey(btcec.S256(), r) + if err != nil { + return nil + } + return (*btcec.PrivateKey)(key) +} + +func randomPublicKey(seed int64) *btcec.PublicKey { + key := randomPrivateKey(seed) + return key.PubKey() +} + +func randomAccountKey(seed int64) [33]byte { + var accountKey [33]byte + + key := randomPublicKey(seed) + copy(accountKey[:], key.SerializeCompressed()) + return accountKey +} + +var overdueExpirationsTestCases = []struct { + name string + blockHeight uint32 + expirations map[[33]byte]uint32 + expirationsPerHeight map[uint32][]*btcec.PublicKey + handledExpirations []*btcec.PublicKey + checks []func(watcher *expiryWatcher) error +}{{ + name: "overdue expirations are handled properly", + blockHeight: 24, + expirations: map[[33]byte]uint32{ + randomAccountKey(0): 24, + randomAccountKey(1): 24, + randomAccountKey(2): 24, + randomAccountKey(3): 27, + }, + handledExpirations: []*btcec.PublicKey{ + randomPublicKey(0), + randomPublicKey(1), + randomPublicKey(2), + }, + expirationsPerHeight: map[uint32][]*btcec.PublicKey{ + 24: { + randomPublicKey(0), + randomPublicKey(1), + randomPublicKey(2), + }, + 27: { + randomPublicKey(27), + }, + }, + checks: []func(watcher *expiryWatcher) error{ + func(watcher *expiryWatcher) error { + left := watcher.expirationsPerHeight[24] + if len(left) != 0 { + return errors.New( + "expirations were not " + + "handled properly", + ) + } + return nil + }, + func(watcher *expiryWatcher) error { + if len(watcher.expirations) != 1 { + return errors.New( + "handled expirations were " + + " not deleted", + ) + } + return nil + }, + }, +}, { + name: "if account wasn't track we ignore it", + blockHeight: 24, + expirationsPerHeight: map[uint32][]*btcec.PublicKey{ + 24: { + randomPublicKey(3), + }, + }, + checks: []func(watcher *expiryWatcher) error{}, +}} + +func TestOverdueExpirations(t *testing.T) { + for _, tc := range overdueExpirationsTestCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + handlers := NewMockEventHandler(mockCtrl) + watcher := NewExpiryWatcher(handlers) + watcher.expirations = tc.expirations + watcher.expirationsPerHeight = tc.expirationsPerHeight + + for _, trader := range tc.handledExpirations { + handlers.EXPECT(). + HandleAccountExpiry( + trader, + tc.blockHeight, + ). + Return(nil) + } + + watcher.NewBlock(tc.blockHeight) + + for _, check := range tc.checks { + if err := check(watcher); err != nil { + t.Fatal(err) + } + } + }) + } +} + +var addAccountExpirationTestCases = []struct { + name string + bestHeight uint32 + initialExpirations map[[33]byte]uint32 + expirations map[*btcec.PublicKey]uint32 + handler func(*btcec.PublicKey, uint32) error + checks []func(watcher *expiryWatcher) error +}{{ + name: "account is tracked happy path", + bestHeight: 20, + expirations: map[*btcec.PublicKey]uint32{ + randomPublicKey(1): 25, + randomPublicKey(2): 25, + randomPublicKey(3): 25, + }, + checks: []func(watcher *expiryWatcher) error{ + func(watcher *expiryWatcher) error { + if len(watcher.expirations) != 3 { + return errors.New( + "account expiry not added", + ) + } + return nil + }, + }, +}, { + name: "account with earlier expiry are directly handled", + bestHeight: 20, + expirations: map[*btcec.PublicKey]uint32{ + randomPublicKey(1): 19, + }, + handler: func(*btcec.PublicKey, uint32) error { + return nil + }, + checks: []func(watcher *expiryWatcher) error{ + func(watcher *expiryWatcher) error { + if len(watcher.expirations) != 0 { + return errors.New("an account with " + + "older expiry hight was added") + } + return nil + }, + }, +}, { + name: "adding an account that we are already watching", + bestHeight: 20, + initialExpirations: map[[33]byte]uint32{ + randomAccountKey(1): 25, + }, + expirations: map[*btcec.PublicKey]uint32{ + randomPublicKey(1): 35, + }, + handler: func(*btcec.PublicKey, uint32) error { + return nil + }, + checks: []func(watcher *expiryWatcher) error{ + func(watcher *expiryWatcher) error { + msg := "account expiry was not updated" + if len(watcher.expirationsPerHeight[35]) != 1 { + return errors.New(msg) + } + + if watcher.expirations[randomAccountKey(1)] != 35 { + return errors.New(msg) + } + return nil + }, + }, +}} + +func TestAddAccountExpiration(t *testing.T) { + for _, tc := range addAccountExpirationTestCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + handlers := NewMockEventHandler(mockCtrl) + watcher := NewExpiryWatcher(handlers) + + if len(tc.initialExpirations) > 0 { + watcher.expirations = tc.initialExpirations + } + watcher.bestHeight = tc.bestHeight + + for trader, height := range tc.expirations { + if height < tc.bestHeight { + handlers.EXPECT(). + HandleAccountExpiry( + trader, + tc.bestHeight, + ). + Return(nil) + } + + watcher.AddAccountExpiration(trader, height) + } + + // The HandleAccountExpiry is executed in the background + // give it some time to ensure that the goroutine has time + // to get executed. This could potentially trigger + // false test failures. + time.Sleep(500 * time.Millisecond) + + for _, check := range tc.checks { + if err := check(watcher); err != nil { + t.Fatal(err) + } + } + }) + } +}