-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
account/watcher: add ExpiryWatcher tests
- Loading branch information
1 parent
17e77c3
commit 1856f79
Showing
2 changed files
with
249 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,248 @@ | ||
package watcher | ||
|
||
import ( | ||
"crypto/ecdsa" | ||
"errors" | ||
"math/rand" | ||
"testing" | ||
"time" | ||
|
||
"github.com/btcsuite/btcd/btcec" | ||
gomock "github.com/golang/mock/gomock" | ||
) | ||
|
||
func randomPrivateKey(seed int64) *btcec.PrivateKey { | ||
r := rand.New(rand.NewSource(seed)) | ||
key, err := ecdsa.GenerateKey(btcec.S256(), r) | ||
if err != nil { | ||
return nil | ||
} | ||
return (*btcec.PrivateKey)(key) | ||
} | ||
|
||
func randomPublicKey(seed int64) *btcec.PublicKey { | ||
key := randomPrivateKey(seed) | ||
return key.PubKey() | ||
} | ||
|
||
func randomAccountKey(seed int64) [33]byte { | ||
var accountKey [33]byte | ||
|
||
key := randomPublicKey(seed) | ||
copy(accountKey[:], key.SerializeCompressed()) | ||
return accountKey | ||
} | ||
|
||
var overdueExpirationsTestCases = []struct { | ||
name string | ||
blockHeight uint32 | ||
expirations map[[33]byte]uint32 | ||
expirationsPerHeight map[uint32][]*btcec.PublicKey | ||
handledExpirations []*btcec.PublicKey | ||
checks []func(watcher *expiryWatcher) error | ||
}{{ | ||
name: "overdue expirations are handled properly", | ||
blockHeight: 24, | ||
expirations: map[[33]byte]uint32{ | ||
randomAccountKey(0): 24, | ||
randomAccountKey(1): 24, | ||
randomAccountKey(2): 24, | ||
randomAccountKey(3): 27, | ||
}, | ||
handledExpirations: []*btcec.PublicKey{ | ||
randomPublicKey(0), | ||
randomPublicKey(1), | ||
randomPublicKey(2), | ||
}, | ||
expirationsPerHeight: map[uint32][]*btcec.PublicKey{ | ||
24: { | ||
randomPublicKey(0), | ||
randomPublicKey(1), | ||
randomPublicKey(2), | ||
}, | ||
27: { | ||
randomPublicKey(27), | ||
}, | ||
}, | ||
checks: []func(watcher *expiryWatcher) error{ | ||
func(watcher *expiryWatcher) error { | ||
left := watcher.expirationsPerHeight[24] | ||
if len(left) != 0 { | ||
return errors.New( | ||
"expirations were not " + | ||
"handled properly", | ||
) | ||
} | ||
return nil | ||
}, | ||
func(watcher *expiryWatcher) error { | ||
if len(watcher.expirations) != 1 { | ||
return errors.New( | ||
"handled expirations were " + | ||
" not deleted", | ||
) | ||
} | ||
return nil | ||
}, | ||
}, | ||
}, { | ||
name: "if account wasn't track we ignore it", | ||
blockHeight: 24, | ||
expirationsPerHeight: map[uint32][]*btcec.PublicKey{ | ||
24: { | ||
randomPublicKey(3), | ||
}, | ||
}, | ||
checks: []func(watcher *expiryWatcher) error{}, | ||
}} | ||
|
||
func TestOverdueExpirations(t *testing.T) { | ||
for _, tc := range overdueExpirationsTestCases { | ||
tc := tc | ||
|
||
t.Run(tc.name, func(t *testing.T) { | ||
t.Parallel() | ||
|
||
mockCtrl := gomock.NewController(t) | ||
defer mockCtrl.Finish() | ||
|
||
handlers := NewMockEventHandler(mockCtrl) | ||
watcher := NewExpiryWatcher(handlers) | ||
watcher.expirations = tc.expirations | ||
watcher.expirationsPerHeight = tc.expirationsPerHeight | ||
|
||
for _, trader := range tc.handledExpirations { | ||
handlers.EXPECT(). | ||
HandleAccountExpiry( | ||
trader, | ||
tc.blockHeight, | ||
). | ||
Return(nil) | ||
} | ||
|
||
watcher.NewBlock(tc.blockHeight) | ||
|
||
for _, check := range tc.checks { | ||
if err := check(watcher); err != nil { | ||
t.Fatal(err) | ||
} | ||
} | ||
}) | ||
} | ||
} | ||
|
||
var addAccountExpirationTestCases = []struct { | ||
name string | ||
bestHeight uint32 | ||
initialExpirations map[[33]byte]uint32 | ||
expirations map[*btcec.PublicKey]uint32 | ||
handler func(*btcec.PublicKey, uint32) error | ||
checks []func(watcher *expiryWatcher) error | ||
}{{ | ||
name: "account is tracked happy path", | ||
bestHeight: 20, | ||
expirations: map[*btcec.PublicKey]uint32{ | ||
randomPublicKey(1): 25, | ||
randomPublicKey(2): 25, | ||
randomPublicKey(3): 25, | ||
}, | ||
checks: []func(watcher *expiryWatcher) error{ | ||
func(watcher *expiryWatcher) error { | ||
if len(watcher.expirations) != 3 { | ||
return errors.New( | ||
"account expiry not added", | ||
) | ||
} | ||
return nil | ||
}, | ||
}, | ||
}, { | ||
name: "account with earlier expiry are directly handled", | ||
bestHeight: 20, | ||
expirations: map[*btcec.PublicKey]uint32{ | ||
randomPublicKey(1): 19, | ||
}, | ||
handler: func(*btcec.PublicKey, uint32) error { | ||
return nil | ||
}, | ||
checks: []func(watcher *expiryWatcher) error{ | ||
func(watcher *expiryWatcher) error { | ||
if len(watcher.expirations) != 0 { | ||
return errors.New("an account with " + | ||
"older expiry hight was added") | ||
} | ||
return nil | ||
}, | ||
}, | ||
}, { | ||
name: "adding an account that we are already watching", | ||
bestHeight: 20, | ||
initialExpirations: map[[33]byte]uint32{ | ||
randomAccountKey(1): 25, | ||
}, | ||
expirations: map[*btcec.PublicKey]uint32{ | ||
randomPublicKey(1): 35, | ||
}, | ||
handler: func(*btcec.PublicKey, uint32) error { | ||
return nil | ||
}, | ||
checks: []func(watcher *expiryWatcher) error{ | ||
func(watcher *expiryWatcher) error { | ||
msg := "account expiry was not updated" | ||
if len(watcher.expirationsPerHeight[35]) != 1 { | ||
return errors.New(msg) | ||
} | ||
|
||
if watcher.expirations[randomAccountKey(1)] != 35 { | ||
return errors.New(msg) | ||
} | ||
return nil | ||
}, | ||
}, | ||
}} | ||
|
||
func TestAddAccountExpiration(t *testing.T) { | ||
for _, tc := range addAccountExpirationTestCases { | ||
tc := tc | ||
|
||
t.Run(tc.name, func(t *testing.T) { | ||
t.Parallel() | ||
|
||
mockCtrl := gomock.NewController(t) | ||
defer mockCtrl.Finish() | ||
|
||
handlers := NewMockEventHandler(mockCtrl) | ||
watcher := NewExpiryWatcher(handlers) | ||
|
||
if len(tc.initialExpirations) > 0 { | ||
watcher.expirations = tc.initialExpirations | ||
} | ||
watcher.bestHeight = tc.bestHeight | ||
|
||
for trader, height := range tc.expirations { | ||
if height < tc.bestHeight { | ||
handlers.EXPECT(). | ||
HandleAccountExpiry( | ||
trader, | ||
tc.bestHeight, | ||
). | ||
Return(nil) | ||
} | ||
|
||
watcher.AddAccountExpiration(trader, height) | ||
} | ||
|
||
// The HandleAccountExpiry is executed in the background | ||
// give it some time to ensure that the goroutine has time | ||
// to get executed. This could potentially trigger | ||
// false test failures. | ||
time.Sleep(500 * time.Millisecond) | ||
|
||
for _, check := range tc.checks { | ||
if err := check(watcher); err != nil { | ||
t.Fatal(err) | ||
} | ||
} | ||
}) | ||
} | ||
} |