diff --git a/CHANGELOG.md b/CHANGELOG.md
index db41a1eda64a..600d0e18299b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -264,6 +264,7 @@ Every module contains its own CHANGELOG.md. Please refer to the module you are i
 * (baseapp) [#20107](https://github.com/cosmos/cosmos-sdk/pull/20107) Avoid header height overwrite block height.
 * (cli) [#20020](https://github.com/cosmos/cosmos-sdk/pull/20020) Make bootstrap-state command support both new and legacy genesis format.
 * (testutil/sims) [#20151](https://github.com/cosmos/cosmos-sdk/pull/20151) Set all signatures and don't overwrite the previous one in `GenSignedMockTx`.
+* (mempool) [#21379](https://github.com/cosmos/cosmos-sdk/pull/21379) Avoid concurrent map read and map write in priority nonce mempool.

 ## [v0.50.6](https://github.com/cosmos/cosmos-sdk/releases/tag/v0.50.6) - 2024-04-22
diff --git a/scripts/build/localnet.mk b/scripts/build/localnet.mk
index 931e8b421285..47c383dc282c 100644
--- a/scripts/build/localnet.mk
+++ b/scripts/build/localnet.mk
@@ -8,11 +8,11 @@ localnet-build-dlv:
 localnet-build-nodes:
 	$(DOCKER) run --rm -v $(CURDIR)/.testnets:/data cosmossdk/simd \
 		testnet init-files -n 4 -o /data --starting-ip-address 192.168.10.2 --keyring-backend=test --listen-ip-address 0.0.0.0
-	docker-compose up -d
+	docker compose up -d

 #? localnet-stop: Stop localnet node
 localnet-stop:
-	docker-compose down
+	docker compose down

 # localnet-start will run a 4-node testnet locally. The nodes are
 # based off the docker images in: ./contrib/images/simd-env
diff --git a/types/mempool/priority_index.go b/types/mempool/priority_index.go
new file mode 100644
index 000000000000..e10a5f513a1e
--- /dev/null
+++ b/types/mempool/priority_index.go
@@ -0,0 +1,188 @@
+package mempool
+
+import (
+	"sync"
+
+	"github.com/huandu/skiplist"
+)
+
+// ConcurrentListElement represents a node in a concurrent priority index,
+// encapsulating a skiplist element and a read-write mutex for safe concurrent access.
+type ConcurrentListElement struct {
+	*skiplist.Element
+	mutex *sync.RWMutex
+	Value interface{}
+}
+
+// Next safely retrieves the next node in the priority index.
+// It acquires a read lock before accessing the next element and releases it afterward.
+func (n *ConcurrentListElement) Next() *ConcurrentListElement {
+	n.mutex.RLock()
+	defer n.mutex.RUnlock()
+	ele := n.Element.Next()
+	if ele == nil {
+		return nil
+	}
+	return &ConcurrentListElement{
+		Element: ele,
+		mutex:   n.mutex,
+		Value:   ele.Value,
+	}
+}
+
+type scoreKey struct {
+	nonce  uint64
+	sender string
+}
+
+type score[C comparable] struct {
+	Priority C
+	Weight   C
+}
+
+// ConcurrentSkipList represents a concurrent priority index,
+// containing a skiplist and a map to track priority counts.
+type ConcurrentSkipList[C comparable] struct {
+	mutex          sync.RWMutex
+	list           *skiplist.SkipList
+	priorityCounts map[C]int
+	scores         map[scoreKey]score[C]
+}
+
+// newConcurrentPriorityIndex initializes a new ConcurrentSkipList.
+// It accepts a Comparable for the skiplist and a boolean to determine whether priority counts should be tracked.
+func newConcurrentPriorityIndex[C comparable](
+	listComparable skiplist.Comparable,
+	priority bool,
+) *ConcurrentSkipList[C] {
+	i := &ConcurrentSkipList[C]{
+		list: skiplist.New(listComparable),
+	}
+	if priority {
+		i.priorityCounts = make(map[C]int)
+		i.scores = make(map[scoreKey]score[C])
+	}
+	return i
+}
+
+// Len returns the number of elements in the priority index.
+// It locks the list for reading to ensure safe access.
+func (i *ConcurrentSkipList[C]) Len() int { + i.mutex.RLock() + defer i.mutex.RUnlock() + return i.list.Len() +} + +// Front retrieves the first node in the priority index. +// It locks the list for reading to ensure safe access. +func (i *ConcurrentSkipList[C]) Front() *ConcurrentListElement { + i.mutex.RLock() + defer i.mutex.RUnlock() + ele := i.list.Front() + if ele == nil { + return nil + } + return &ConcurrentListElement{ + Element: ele, + mutex: &i.mutex, + Value: ele.Value, + } +} + +// GetCount retrieves the count of a specific key from the priority counts map. +// It locks priorityCounts for reading to ensure safe access. +func (i *ConcurrentSkipList[C]) GetCount(key C) int { + i.mutex.RLock() + defer i.mutex.RUnlock() + if i.priorityCounts == nil { + return -1 + } + return i.priorityCounts[key] +} + +// CloneCounts creates a copy of the priority counts map. +// It locks priorityCounts for reading to ensure safe access. +func (i *ConcurrentSkipList[C]) CloneCounts() map[C]int { + i.mutex.RLock() + defer i.mutex.RUnlock() + if i.priorityCounts == nil { + return nil + } + + counts := make(map[C]int) + for k, v := range i.priorityCounts { + counts[k] = v + } + return counts +} + +// GetScore retrieves the score associated with a specific nonce and sender. +// It returns a pointer to the score if found, or nil if not found. +// It locks the scores for reading to ensure safe access. +func (i *ConcurrentSkipList[C]) GetScore(nonce uint64, sender string) *score[C] { //revive:disable:unexported-return + i.mutex.RLock() + defer i.mutex.RUnlock() + score, ok := i.scores[scoreKey{nonce: nonce, sender: sender}] + if !ok { + return nil + } + return &score +} + +// Get retrieves a node corresponding to a specific key from the priority index. +// It locks the list for reading to ensure safe access. +func (i *ConcurrentSkipList[C]) Get(key txMeta[C]) *ConcurrentListElement { + i.mutex.RLock() + defer i.mutex.RUnlock() + ele := i.list.Get(key) + if ele == nil { + return nil + } + return &ConcurrentListElement{ + Element: ele, + mutex: &i.mutex, + Value: ele.Value, + } +} + +// Set inserts or updates a node in the priority index with the given key and value. +// It locks priorityCounts, scores and list for writing to ensure safe access and updates the priority count. +func (i *ConcurrentSkipList[C]) Set(key txMeta[C], value any) *ConcurrentListElement { + i.mutex.Lock() + defer i.mutex.Unlock() + if i.priorityCounts != nil { + i.priorityCounts[key.priority]++ + } + if i.scores != nil { + i.scores[scoreKey{ + nonce: key.nonce, + sender: key.sender, + }] = score[C]{ + Priority: key.priority, + Weight: key.weight, + } + } + ele := i.list.Set(key, value) + if ele == nil { + return nil + } + return &ConcurrentListElement{ + Element: ele, + mutex: &i.mutex, + Value: ele.Value, + } +} + +// Remove deletes a node from the priority index using the specified key. +// It locks priorityCounts and scores for writing to ensure safe access and decrements the priority count. 
+func (i *ConcurrentSkipList[C]) Remove(key txMeta[C]) {
+	i.mutex.Lock()
+	defer i.mutex.Unlock()
+	if i.priorityCounts != nil {
+		i.priorityCounts[key.priority]--
+	}
+	if i.scores != nil {
+		delete(i.scores, scoreKey{nonce: key.nonce, sender: key.sender})
+	}
+	i.list.Remove(key)
+}
diff --git a/types/mempool/priority_index_test.go b/types/mempool/priority_index_test.go
new file mode 100644
index 000000000000..621d08887a69
--- /dev/null
+++ b/types/mempool/priority_index_test.go
@@ -0,0 +1,263 @@
+package mempool
+
+import (
+	"sync"
+	"testing"
+
+	"github.com/huandu/skiplist"
+	"github.com/stretchr/testify/require"
+)
+
+func TestConcurrentPriorityNode_Next(t *testing.T) {
+	list := skiplist.New(skiplist.Int64)
+	mutex := &sync.RWMutex{}
+
+	for i := 0; i < 5; i++ {
+		list.Set(int64(i), int64(i))
+	}
+
+	firstEle := list.Front()
+	require.NotNil(t, firstEle)
+	node := &ConcurrentListElement{
+		Element: firstEle,
+		mutex:   mutex,
+	}
+
+	for i := 0; i < 5; i++ {
+		require.NotNil(t, node)
+		nextNode := node.Next()
+		if nextNode != nil {
+			expected := int64(i + 1)
+			require.Equal(t, expected, nextNode.Value)
+		}
+		node = nextNode
+	}
+
+	// expected node to be nil after traversing all elements
+	require.Nil(t, node)
+}
+
+func TestConcurrentPriorityIndex_Len(t *testing.T) {
+	index := newConcurrentPriorityIndex[int](skiplist.LessThanFunc(func(a, b any) int {
+		return skiplist.Uint64.Compare(b.(txMeta[int]).nonce, a.(txMeta[int]).nonce)
+	}), true)
+
+	total := 5
+	for i := 0; i < total; i++ {
+		index.Set(txMeta[int]{nonce: uint64(i)}, i)
+	}
+
+	require.Equal(t, total, index.Len())
+}
+
+func TestConcurrentPriorityIndex_Front(t *testing.T) {
+	index := newConcurrentPriorityIndex[int](skiplist.LessThanFunc(func(a, b any) int {
+		return skiplist.Int.Compare(b.(txMeta[int]).priority, a.(txMeta[int]).priority)
+	}), true)
+
+	for i := 0; i < 5; i++ {
+		index.Set(txMeta[int]{priority: i}, i)
+	}
+
+	frontNode := index.Front()
+	require.NotNil(t, frontNode)
+	require.Equal(t, 0, frontNode.Element.Value)
+}
+
+func TestConcurrentSkipList_GetCount(t *testing.T) {
+	list := &ConcurrentSkipList[int]{
+		mutex:          sync.RWMutex{},
+		priorityCounts: make(map[int]int),
+	}
+
+	list.priorityCounts = nil
+	count := list.GetCount(10)
+	require.Equal(t, -1, count)
+
+	list.priorityCounts = make(map[int]int)
+	list.priorityCounts[10] = 5
+	list.priorityCounts[20] = 3
+
+	count = list.GetCount(10)
+	require.Equal(t, 5, count)
+
+	count = list.GetCount(20)
+	require.Equal(t, 3, count)
+
+	count = list.GetCount(30)
+	require.Equal(t, 0, count)
+}
+
+func TestConcurrentSkipList_CloneCounts(t *testing.T) {
+	list := &ConcurrentSkipList[int]{
+		mutex:          sync.RWMutex{},
+		priorityCounts: make(map[int]int),
+	}
+	list.priorityCounts = nil
+	counts := list.CloneCounts()
+	require.Nil(t, counts)
+	list.priorityCounts = map[int]int{
+		10: 5,
+		20: 3,
+		30: 7,
+	}
+
+	counts = list.CloneCounts()
+	require.NotNil(t, counts)
+	require.Equal(t, 5, counts[10])
+	require.Equal(t, 3, counts[20])
+	require.Equal(t, 7, counts[30])
+
+	// check that the cloned map is a separate instance after modification
+	counts[10] = 99
+	require.Equal(t, 5, list.priorityCounts[10])
+}
+
+func TestConcurrentSkipList_GetScore(t *testing.T) {
+	list := &ConcurrentSkipList[int]{
+		mutex: sync.RWMutex{},
+		scores: map[scoreKey]score[int]{
+			{nonce: 1, sender: "sender1"}: {Priority: 10, Weight: 1},
+			{nonce: 2, sender: "sender2"}: {Priority: 20, Weight: 2},
+		},
+	}
+
+	score := list.GetScore(1, "sender1")
+	require.NotNil(t, score)
+	require.Equal(t, 10, score.Priority)
+	require.Equal(t, 1, score.Weight)
+
+	score = list.GetScore(1, "sender2")
+	require.Nil(t, score)
+	score = list.GetScore(2, "sender1")
+	require.Nil(t, score)
+	score = list.GetScore(3, "sender3")
+	require.Nil(t, score)
+}
+
+func TestConcurrentSkipList_Get(t *testing.T) {
+	list := skiplist.New(skiplist.LessThanFunc(func(a, b any) int {
+		return skiplist.Uint64.Compare(b.(txMeta[int]).nonce, a.(txMeta[int]).nonce)
+	}))
+	concurrentList := &ConcurrentSkipList[int]{
+		mutex: sync.RWMutex{},
+		list:  list,
+	}
+
+	for i := 0; i < 5; i++ {
+		key := txMeta[int]{nonce: uint64(i)}
+		concurrentList.Set(key, i)
+	}
+
+	key := txMeta[int]{nonce: 2}
+	ele := concurrentList.Get(key)
+	require.NotNil(t, ele)
+	require.Equal(t, 2, ele.Value)
+
+	nonExist := txMeta[int]{nonce: 10}
+	ele = concurrentList.Get(nonExist)
+	require.Nil(t, ele)
+}
+
+func TestConcurrentSkipList_Set(t *testing.T) {
+	list := skiplist.New(skiplist.LessThanFunc(func(a, b any) int {
+		return skiplist.Uint64.Compare(a.(txMeta[int]).nonce, b.(txMeta[int]).nonce)
+	}))
+
+	concurrentList := &ConcurrentSkipList[int]{
+		mutex:          sync.RWMutex{},
+		list:           list,
+		priorityCounts: make(map[int]int),
+		scores:         make(map[scoreKey]score[int]),
+	}
+
+	key1 := txMeta[int]{nonce: 1, sender: "sender1", priority: 10, weight: 2}
+	ele1 := concurrentList.Set(key1, 1)
+	require.NotNil(t, ele1)
+	require.Equal(t, 1, ele1.Element.Value)
+	require.Equal(t, 1, concurrentList.priorityCounts[key1.priority])
+	require.Equal(t, score[int]{Priority: 10, Weight: 2}, concurrentList.scores[scoreKey{nonce: 1, sender: "sender1"}])
+
+	// update an existing element
+	key2 := txMeta[int]{nonce: 1, sender: "sender1", priority: 20, weight: 3}
+	ele2 := concurrentList.Set(key2, 2)
+	require.NotNil(t, ele2)
+	require.Equal(t, 2, ele2.Element.Value)
+	require.Equal(t, 1, concurrentList.priorityCounts[key2.priority])
+	require.Equal(t, score[int]{Priority: 20, Weight: 3}, concurrentList.scores[scoreKey{nonce: 1, sender: "sender1"}])
+
+	// insert a new element
+	key3 := txMeta[int]{nonce: 2, sender: "sender2", priority: 15, weight: 1}
+	ele3 := concurrentList.Set(key3, 3)
+	require.NotNil(t, ele3)
+	require.Equal(t, 3, ele3.Element.Value)
+	require.Equal(t, 1, concurrentList.priorityCounts[key3.priority])
+	require.Equal(t, score[int]{Priority: 15, Weight: 1}, concurrentList.scores[scoreKey{nonce: 2, sender: "sender2"}])
+}
+
+func TestConcurrentSkipList_Remove(t *testing.T) {
+	list := skiplist.New(skiplist.LessThanFunc(func(a, b any) int {
+		return skiplist.Uint64.Compare(a.(txMeta[int]).nonce, b.(txMeta[int]).nonce)
+	}))
+
+	concurrentList := &ConcurrentSkipList[int]{
+		mutex:          sync.RWMutex{},
+		list:           list,
+		priorityCounts: make(map[int]int),
+		scores:         make(map[scoreKey]score[int]),
+	}
+
+	key1 := txMeta[int]{nonce: 1, sender: "sender1", priority: 10, weight: 2}
+	concurrentList.Set(key1, 1)
+
+	key2 := txMeta[int]{nonce: 2, sender: "sender2", priority: 20, weight: 3}
+	concurrentList.Set(key2, 2)
+
+	require.Equal(t, 1, concurrentList.priorityCounts[key1.priority])
+	require.Equal(t, 1, concurrentList.priorityCounts[key2.priority])
+	require.Equal(t, score[int]{Priority: 10, Weight: 2}, concurrentList.scores[scoreKey{nonce: 1, sender: "sender1"}])
+	require.Equal(t, score[int]{Priority: 20, Weight: 3}, concurrentList.scores[scoreKey{nonce: 2, sender: "sender2"}])
+
+	concurrentList.Remove(key1)
+	require.Equal(t, 0, concurrentList.priorityCounts[key1.priority])
+	require.NotContains(t, concurrentList.scores, scoreKey{nonce: key1.nonce, sender: key1.sender})
+	require.Nil(t, concurrentList.Get(key1))
+
+	require.Equal(t, 1, concurrentList.priorityCounts[key2.priority])
+	require.Equal(t, score[int]{Priority: 20, Weight: 3}, concurrentList.scores[scoreKey{nonce: 2, sender: "sender2"}])
+}
+
+func TestConcurrentPriorityIndex_Concurrent(t *testing.T) {
+	index := newConcurrentPriorityIndex[int](skiplist.LessThanFunc(func(a, b any) int {
+		return skiplist.Uint64.Compare(b.(txMeta[int]).nonce, a.(txMeta[int]).nonce)
+	}), true)
+
+	total := 10
+	for i := 0; i < total; i++ {
+		index.Set(txMeta[int]{nonce: uint64(i)}, i)
+	}
+
+	wg := new(sync.WaitGroup)
+	wg.Add(2)
+	go func() {
+		for i := 0; i < total; i++ {
+			index.Set(txMeta[int]{nonce: uint64(i)}, i)
+		}
+		wg.Done()
+	}()
+
+	go func() {
+		for i := 0; i < total; i++ {
+			index.Remove(txMeta[int]{nonce: uint64(i)})
+		}
+		wg.Done()
+	}()
+
+	for i := 0; i < total; i++ {
+		ele := index.Get(txMeta[int]{nonce: uint64(i)})
+		if ele != nil {
+			require.Equal(t, i, ele.Value)
+		}
+	}
+	wg.Wait()
+}
diff --git a/types/mempool/priority_nonce.go b/types/mempool/priority_nonce.go
index a927693410ef..2397764f428c 100644
--- a/types/mempool/priority_nonce.go
+++ b/types/mempool/priority_nonce.go
@@ -53,20 +53,17 @@ type (
 	// priority to other sender txs and must be partially ordered by both sender-nonce
 	// and priority.
 	PriorityNonceMempool[C comparable] struct {
-		mtx            sync.Mutex
-		priorityIndex  *skiplist.SkipList
-		priorityCounts map[C]int
-		senderIndices  map[string]*skiplist.SkipList
-		scores         map[txMeta[C]]txMeta[C]
-		cfg            PriorityNonceMempoolConfig[C]
+		priorityIndex *ConcurrentSkipList[C]
+		senderIndices sync.Map
+		cfg           PriorityNonceMempoolConfig[C]
 	}

 	// PriorityNonceIterator defines an iterator that is used for mempool iteration
 	// on Select().
 	PriorityNonceIterator[C comparable] struct {
 		mempool       *PriorityNonceMempool[C]
-		priorityNode  *skiplist.Element
-		senderCursors map[string]*skiplist.Element
+		priorityNode  *ConcurrentListElement
+		senderCursors map[string]*ConcurrentListElement
 		sender        string
 		nextPriority  C
 	}
@@ -99,7 +96,7 @@ type (
 		// with the same priority
 		weight C
 		// senderElement is a pointer to the transaction's element in the sender index
-		senderElement *skiplist.Element
+		senderElement *ConcurrentListElement
 	}
 )

@@ -165,11 +162,9 @@ func NewPriorityMempool[C comparable](cfg PriorityNonceMempoolConfig[C]) *Priori
 		cfg.SignerExtractor = NewDefaultSignerExtractionAdapter()
 	}
 	mp := &PriorityNonceMempool[C]{
-		priorityIndex:  skiplist.New(skiplistComparable(cfg.TxPriority)),
-		priorityCounts: make(map[C]int),
-		senderIndices:  make(map[string]*skiplist.SkipList),
-		scores:         make(map[txMeta[C]]txMeta[C]),
-		cfg:            cfg,
+		priorityIndex: newConcurrentPriorityIndex[C](skiplistComparable(cfg.TxPriority), true),
+		senderIndices: sync.Map{},
+		cfg:           cfg,
 	}

 	return mp
@@ -184,12 +179,12 @@ func DefaultPriorityMempool() *PriorityNonceMempool[int64] {
 // i.e. the next valid transaction for the sender. If no such transaction exists,
 // nil will be returned.
func (mp *PriorityNonceMempool[C]) NextSenderTx(sender string) sdk.Tx { - senderIndex, ok := mp.senderIndices[sender] + senderIndex, ok := mp.senderIndices.Load(sender) if !ok { return nil } - - cursor := senderIndex.Front() + senderIndexList := senderIndex.(*ConcurrentSkipList[C]) + cursor := senderIndexList.Front() return cursor.Value.(sdk.Tx) } @@ -203,8 +198,6 @@ func (mp *PriorityNonceMempool[C]) NextSenderTx(sender string) sdk.Tx { // Inserting a duplicate tx with a different priority overwrites the existing tx, // changing the total order of the mempool. func (mp *PriorityNonceMempool[C]) Insert(ctx context.Context, tx sdk.Tx) error { - mp.mtx.Lock() - defer mp.mtx.Unlock() if mp.cfg.MaxTx > 0 && mp.priorityIndex.Len() >= mp.cfg.MaxTx { return ErrMempoolTxMaxCapacity } else if mp.cfg.MaxTx < 0 { @@ -224,15 +217,14 @@ func (mp *PriorityNonceMempool[C]) Insert(ctx context.Context, tx sdk.Tx) error priority := mp.cfg.TxPriority.GetTxPriority(ctx, tx) nonce := sig.Sequence key := txMeta[C]{nonce: nonce, priority: priority, sender: sender} - - senderIndex, ok := mp.senderIndices[sender] + senderIndex, ok := mp.senderIndices.Load(sender) if !ok { - senderIndex = skiplist.New(skiplist.LessThanFunc(func(a, b any) int { + senderIndex = newConcurrentPriorityIndex[C](skiplist.LessThanFunc(func(a, b any) int { return skiplist.Uint64.Compare(b.(txMeta[C]).nonce, a.(txMeta[C]).nonce) - })) + }), false) // initialize sender index if not found - mp.senderIndices[sender] = senderIndex + mp.senderIndices.Store(sender, senderIndex) } // Since mp.priorityIndex is scored by priority, then sender, then nonce, a @@ -242,34 +234,32 @@ func (mp *PriorityNonceMempool[C]) Insert(ctx context.Context, tx sdk.Tx) error // // This O(log n) remove operation is rare and only happens when a tx's priority // changes. - sk := txMeta[C]{nonce: nonce, sender: sender} - if oldScore, txExists := mp.scores[sk]; txExists { - if mp.cfg.TxReplacement != nil && !mp.cfg.TxReplacement(oldScore.priority, priority, senderIndex.Get(key).Value.(sdk.Tx), tx) { - return fmt.Errorf( - "tx doesn't fit the replacement rule, oldPriority: %v, newPriority: %v, oldTx: %v, newTx: %v", - oldScore.priority, - priority, - senderIndex.Get(key).Value.(sdk.Tx), - tx, - ) + if oldScore := mp.priorityIndex.GetScore(nonce, sender); oldScore != nil { + if mp.cfg.TxReplacement != nil { + senderIndexList := senderIndex.(*ConcurrentSkipList[C]) + oldTx := senderIndexList.Get(key).Value.(sdk.Tx) + if !mp.cfg.TxReplacement(oldScore.Priority, priority, oldTx, tx) { + return fmt.Errorf( + "tx doesn't fit the replacement rule: old priority=%v, new priority=%v, old tx=%v, new tx=%v", + oldScore.Priority, + priority, + oldTx, + tx, + ) + } } mp.priorityIndex.Remove(txMeta[C]{ nonce: nonce, sender: sender, - priority: oldScore.priority, - weight: oldScore.weight, + priority: oldScore.Priority, + weight: oldScore.Weight, }) - mp.priorityCounts[oldScore.priority]-- } - mp.priorityCounts[priority]++ - // Since senderIndex is scored by nonce, a changed priority will overwrite the // existing key. 
- key.senderElement = senderIndex.Set(key, tx) - - mp.scores[sk] = txMeta[C]{priority: priority} + key.senderElement = senderIndex.(*ConcurrentSkipList[C]).Set(key, tx) mp.priorityIndex.Set(key, tx) return nil @@ -304,11 +294,16 @@ func (i *PriorityNonceIterator[C]) Next() Iterator { if i.priorityNode == nil { return nil } + senderIndexValue, ok := i.mempool.senderIndices.Load(i.sender) + if !ok { + return i.iteratePriority() + } + senderIndex := senderIndexValue.(*ConcurrentSkipList[C]) cursor, ok := i.senderCursors[i.sender] if !ok { // beginning of sender iteration - cursor = i.mempool.senderIndices[i.sender].Front() + cursor = senderIndex.Front() } else { // middle of sender iteration cursor = cursor.Next() @@ -325,11 +320,16 @@ func (i *PriorityNonceIterator[C]) Next() Iterator { // priority in the pool. if i.mempool.cfg.TxPriority.Compare(key.priority, i.nextPriority) < 0 { return i.iteratePriority() - } else if i.priorityNode.Next() != nil && i.mempool.cfg.TxPriority.Compare(key.priority, i.nextPriority) == 0 { + } + nextElem := i.priorityNode.Next() + if nextElem != nil && i.mempool.cfg.TxPriority.Compare(key.priority, i.nextPriority) == 0 { + var weight C + if score := i.mempool.priorityIndex.GetScore(key.nonce, key.sender); score != nil { + weight = score.Weight + } // Weight is incorporated into the priority index key only (not sender index) // so we must fetch it here from the scores map. - weight := i.mempool.scores[txMeta[C]{nonce: key.nonce, sender: key.sender}].weight - if i.mempool.cfg.TxPriority.Compare(weight, i.priorityNode.Next().Key().(txMeta[C]).weight) < 0 { + if i.mempool.cfg.TxPriority.Compare(weight, nextElem.Key().(txMeta[C]).weight) < 0 { return i.iteratePriority() } } @@ -352,8 +352,6 @@ func (i *PriorityNonceIterator[C]) Tx() sdk.Tx { // NOTE: It is not safe to use this iterator while removing transactions from // the underlying mempool. func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) Iterator { - mp.mtx.Lock() - defer mp.mtx.Unlock() if mp.priorityIndex.Len() == 0 { return nil } @@ -362,7 +360,7 @@ func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) Iterato iterator := &PriorityNonceIterator[C]{ mempool: mp, - senderCursors: make(map[string]*skiplist.Element), + senderCursors: make(map[string]*ConcurrentListElement), } return iterator.iteratePriority() @@ -376,11 +374,10 @@ type reorderKey[C comparable] struct { func (mp *PriorityNonceMempool[C]) reorderPriorityTies() { node := mp.priorityIndex.Front() - var reordering []reorderKey[C] for node != nil { key := node.Key().(txMeta[C]) - if mp.priorityCounts[key.priority] > 1 { + if mp.priorityIndex.GetCount(key.priority) > 1 { newKey := key newKey.weight = senderWeight(mp.cfg.TxPriority, key.senderElement) reordering = append(reordering, reorderKey[C]{deleteKey: key, insertKey: newKey, tx: node.Value.(sdk.Tx)}) @@ -391,9 +388,7 @@ func (mp *PriorityNonceMempool[C]) reorderPriorityTies() { for _, k := range reordering { mp.priorityIndex.Remove(k.deleteKey) - delete(mp.scores, txMeta[C]{nonce: k.deleteKey.nonce, sender: k.deleteKey.sender}) mp.priorityIndex.Set(k.insertKey, k.tx) - mp.scores[txMeta[C]{nonce: k.insertKey.nonce, sender: k.insertKey.sender}] = k.insertKey } } @@ -401,7 +396,7 @@ func (mp *PriorityNonceMempool[C]) reorderPriorityTies() { // defined as the first (nonce-wise) same sender tx with a priority not equal to // t. It is used to resolve priority collisions, that is when 2 or more txs from // different senders have the same priority. 
-func senderWeight[C comparable](txPriority TxPriority[C], senderCursor *skiplist.Element) C { +func senderWeight[C comparable](txPriority TxPriority[C], senderCursor *ConcurrentListElement) C { if senderCursor == nil { return txPriority.MinValue } @@ -422,16 +417,12 @@ func senderWeight[C comparable](txPriority TxPriority[C], senderCursor *skiplist // CountTx returns the number of transactions in the mempool. func (mp *PriorityNonceMempool[C]) CountTx() int { - mp.mtx.Lock() - defer mp.mtx.Unlock() return mp.priorityIndex.Len() } // Remove removes a transaction from the mempool in O(log n) time, returning an // error if unsuccessful. func (mp *PriorityNonceMempool[C]) Remove(tx sdk.Tx) error { - mp.mtx.Lock() - defer mp.mtx.Unlock() sigs, err := mp.cfg.SignerExtractor.GetSigners(tx) if err != nil { return err @@ -444,22 +435,19 @@ func (mp *PriorityNonceMempool[C]) Remove(tx sdk.Tx) error { sender := sig.Signer.String() nonce := sig.Sequence - scoreKey := txMeta[C]{nonce: nonce, sender: sender} - score, ok := mp.scores[scoreKey] - if !ok { + score := mp.priorityIndex.GetScore(nonce, sender) + if score == nil { return ErrTxNotFound } - tk := txMeta[C]{nonce: nonce, priority: score.priority, sender: sender, weight: score.weight} + tk := txMeta[C]{nonce: nonce, priority: score.Priority, sender: sender, weight: score.Weight} - senderTxs, ok := mp.senderIndices[sender] + senderTxs, ok := mp.senderIndices.Load(sender) if !ok { return fmt.Errorf("sender %s not found", sender) } mp.priorityIndex.Remove(tk) - senderTxs.Remove(tk) - delete(mp.scores, scoreKey) - mp.priorityCounts[score.priority]-- + senderTxs.(*ConcurrentSkipList[C]).Remove(tk) return nil } @@ -470,24 +458,26 @@ func IsEmpty[C comparable](mempool Mempool) error { return errors.New("priorityIndex not empty") } - countKeys := make([]C, 0, len(mp.priorityCounts)) - for k := range mp.priorityCounts { - countKeys = append(countKeys, k) - } - - for _, k := range countKeys { - if mp.priorityCounts[k] != 0 { - return fmt.Errorf("priorityCounts not zero at %v, got %v", k, mp.priorityCounts[k]) + priorityCounts := mp.priorityIndex.CloneCounts() + for k, count := range priorityCounts { + if count != 0 { + return fmt.Errorf("priorityCounts not zero at %v, got %v", k, count) } } - senderKeys := make([]string, 0, len(mp.senderIndices)) - for k := range mp.senderIndices { - senderKeys = append(senderKeys, k) - } + senderKeys := make([]string, 0) + mp.senderIndices.Range(func(key, value interface{}) bool { + senderKeys = append(senderKeys, key.(string)) + return true + }) for _, k := range senderKeys { - if mp.senderIndices[k].Len() != 0 { + senderIndexValue, ok := mp.senderIndices.Load(k) + if !ok { + continue + } + senderIndex := senderIndexValue.(*ConcurrentSkipList[C]) + if senderIndex.Len() != 0 { return fmt.Errorf("senderIndex not empty for sender %v", k) } } diff --git a/types/mempool/priority_nonce_test.go b/types/mempool/priority_nonce_test.go index 0a2f40355fbd..7d793040b56d 100644 --- a/types/mempool/priority_nonce_test.go +++ b/types/mempool/priority_nonce_test.go @@ -4,6 +4,7 @@ import ( "fmt" "math" "math/rand" + "sync" "testing" "time" @@ -383,14 +384,28 @@ func (s *MempoolTestSuite) TestIterator() { } // iterate through txs + wg := new(sync.WaitGroup) + wg.Add(1) + go func() { + for j := len(tt.txs); j < len(tt.txs)+100; j++ { + tx := testTx{id: j, priority: int64(rand.Intn(100)), nonce: uint64(j), address: sa} + c := ctx.WithPriority(tx.priority) + _ = pool.Insert(c, tx) + } + wg.Done() + }() + iterator := pool.Select(ctx, nil) 
for iterator != nil { tx := iterator.Tx().(testTx) - require.Equal(t, tt.txs[tx.id].p, int(tx.priority)) - require.Equal(t, tt.txs[tx.id].n, int(tx.nonce)) - require.Equal(t, tt.txs[tx.id].a, tx.address) + if tx.id < len(tt.txs) { + require.Equal(t, tt.txs[tx.id].p, int(tx.priority)) + require.Equal(t, tt.txs[tx.id].n, int(tx.nonce)) + require.Equal(t, tt.txs[tx.id].a, tx.address) + } iterator = iterator.Next() } + wg.Wait() }) } }
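A note on exercising the new index. The sketch below is not part of the patch: the test name and workload are hypothetical, and it assumes placement alongside the tests above in the mempool package, since ConcurrentSkipList, newConcurrentPriorityIndex, and txMeta are unexported. With the old plain priorityCounts and scores maps, this kind of workload could fault with "concurrent map read and map write"; here every access path takes the index's RWMutex, so `go test -race` should stay quiet.

package mempool

import (
	"sync"
	"testing"

	"github.com/huandu/skiplist"
)

// TestConcurrentSkipList_Stress (hypothetical) hammers a single index from
// four goroutines, each owning a disjoint nonce range.
func TestConcurrentSkipList_Stress(t *testing.T) {
	index := newConcurrentPriorityIndex[int64](skiplist.LessThanFunc(func(a, b any) int {
		return skiplist.Uint64.Compare(b.(txMeta[int64]).nonce, a.(txMeta[int64]).nonce)
	}), true)

	var wg sync.WaitGroup
	for g := 0; g < 4; g++ {
		wg.Add(1)
		go func(offset int) {
			defer wg.Done()
			for i := 0; i < 100; i++ {
				key := txMeta[int64]{nonce: uint64(offset*100 + i), sender: "sender", priority: int64(i % 10)}
				index.Set(key, i)                         // writes the list, priorityCounts, and scores together
				_ = index.GetScore(key.nonce, key.sender) // concurrent read of the scores map
				_ = index.CloneCounts()                   // concurrent read of the priorityCounts map
				index.Remove(key)                         // deletes from all three structures
			}
		}(g)
	}
	wg.Wait()

	// Every goroutine removed exactly what it inserted, so the index drains.
	if index.Len() != 0 {
		t.Fatalf("expected empty index, got %d elements", index.Len())
	}
}

Note the design choice this leans on: Set and Remove update the skiplist, priorityCounts, and scores under a single write lock, which is what keeps the three structures mutually consistent; one mutex per map would still let a reader observe a half-applied insert.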