Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize AddBlock logic for pos fee estimator #1218

Open
wants to merge 1 commit into
base: feature/proof-of-stake
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 62 additions & 28 deletions lib/pos_fee_estimator.go
Original file line number Diff line number Diff line change
@@ -83,7 +83,7 @@ func (posFeeEstimator *PoSFeeEstimator) Init(
posFeeEstimator.pastBlocksTransactionRegister.Init(globalParams.Copy())

// Sort the past blocks by height just to be safe.
sortedPastBlocks := posFeeEstimator.cleanUpPastBlocks(pastBlocks)
sortedPastBlocks, _ := posFeeEstimator.cleanUpPastBlocks(pastBlocks)

// Add all the txns from the past blocks to the new pastBlocksTransactionRegister.
for _, block := range sortedPastBlocks {
@@ -116,22 +116,36 @@ func (posFeeEstimator *PoSFeeEstimator) AddBlock(block *MsgDeSoBlock) error {
func (posFeeEstimator *PoSFeeEstimator) addBlockNoLock(block *MsgDeSoBlock) error {
// Create a new slice to house the new past blocks and add the new block to it.
newPastBlocks := append(posFeeEstimator.cachedBlocks, block)
newPastBlocks = posFeeEstimator.cleanUpPastBlocks(newPastBlocks)
var removedBlocks []*MsgDeSoBlock
newPastBlocks, removedBlocks = posFeeEstimator.cleanUpPastBlocks(newPastBlocks)

// Create a clean transaction register to add the blocks' transactions.
newTransactionRegister := NewTransactionRegister()
newTransactionRegister.Init(posFeeEstimator.globalParams.Copy())
incomingBlockHash, err := block.Hash()
if err != nil {
return errors.Wrap(err, "PoSFeeEstimator.addBlockNoLock: error computing blockHash")
}

// Remove all blocks that were pruned are no longer in the cached blocks.
for _, removedBlock := range removedBlocks {
if err = posFeeEstimator.removeBlockNoLock(removedBlock); err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This operation ends up reinitializing the transaction register for every block that's removed. I think it ends up being even less efficient than the original code

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removeBlockNoLock was updated to modify the transaction register in place w/o creating a new one.

return errors.Wrap(err, "PoSFeeEstimator.addBlockNoLock: error removing block from PoSFeeEstimator")
}
}

incomingBlockRemoved := collections.Any(removedBlocks, func(removedBlock *MsgDeSoBlock) bool {
// It's not possible for us to have added a block that can't be hashed.
removedBlockHash, _ := removedBlock.Hash()
return removedBlockHash.IsEqual(incomingBlockHash)
})

// Add all transactions from the block to the pastBlocksTransactionRegister.
for _, pastBlock := range newPastBlocks {
if err := addBlockToTransactionRegister(newTransactionRegister, pastBlock); err != nil {
if !incomingBlockRemoved {
// Add all transactions from the block to the pastBlocksTransactionRegister.
if err := addBlockToTransactionRegister(posFeeEstimator.pastBlocksTransactionRegister, block); err != nil {
return errors.Wrap(err, "PoSFeeEstimator.addBlockNoLock: error adding block to pastBlocksTransactionRegister")
}
}

// Update the cached blocks and pastBlocksTransactionRegister.
// Update the cached blocks
posFeeEstimator.cachedBlocks = newPastBlocks
posFeeEstimator.pastBlocksTransactionRegister = newTransactionRegister

return nil
}
@@ -157,6 +171,25 @@ func addBlockToTransactionRegister(txnRegister *TransactionRegister, block *MsgD
return nil
}

func removeBlockFromTransactionRegister(txnRegister *TransactionRegister, block *MsgDeSoBlock) error {
for _, txn := range block.Txns {
// We explicitly exclude block reward transactions as they do not have fees and were never
// added in the first place.
if txn.TxnMeta.GetTxnType() == TxnTypeBlockReward {
continue
}
mtxn, err := NewMempoolTx(txn, NanoSecondsToTime(block.Header.TstampNanoSecs), block.Header.Height)
if err != nil {
return errors.Wrap(err, "PoSFeeEstimator.removeBlockFromTransactionRegister: error creating MempoolTx")
}
if err = txnRegister.RemoveTransaction(mtxn); err != nil {
return errors.Wrap(err,
"PoSFeeEstimator.removeBlockFromTransactionRegister: error removing txn from pastBlocksTransactionRegister")
}
}
return nil
}

// RemoveBlock removes a block from the PoSFeeEstimator. This will remove all the transactions from the block
// from the pastBlocksTransactionRegister and remove the block from the cache.
func (posFeeEstimator *PoSFeeEstimator) RemoveBlock(block *MsgDeSoBlock) error {
@@ -186,20 +219,14 @@ func (posFeeEstimator *PoSFeeEstimator) removeBlockNoLock(block *MsgDeSoBlock) e
return !blockHash.IsEqual(cachedBlockHash)
})

// Create a clean transaction register to add the blocks' transactions.
newTransactionRegister := NewTransactionRegister()
newTransactionRegister.Init(posFeeEstimator.globalParams.Copy())

// Add all transactions from the past blocks to the transaction register.
for _, pastBlock := range newPastBlocks {
if err := addBlockToTransactionRegister(newTransactionRegister, pastBlock); err != nil {
return errors.Wrap(err, "PoSFeeEstimator.removeBlockNoLock: error adding block to transaction register")
}
// Remove the block from the transaction register
if err = removeBlockFromTransactionRegister(posFeeEstimator.pastBlocksTransactionRegister, block); err != nil {
return errors.Wrap(err,
"PoSFeeEstimator.removeBlockNoLock: error removing block from pastBlocksTransactionRegister")
}

// Update the cached blocks and pastBlocksTransactionRegister.
posFeeEstimator.cachedBlocks = newPastBlocks
posFeeEstimator.pastBlocksTransactionRegister = newTransactionRegister

return nil
}
@@ -226,23 +253,28 @@ func (posFeeEstimator *PoSFeeEstimator) UpdateGlobalParams(globalParams *GlobalP
}

// cleanUpPastBlocks cleans up the input blocks slice, deduping, sorting, and pruning the blocks by height.
func (posFeeEstimator *PoSFeeEstimator) cleanUpPastBlocks(blocks []*MsgDeSoBlock) []*MsgDeSoBlock {
dedupedBlocks := posFeeEstimator.dedupeBlocksByBlockHeight(blocks)
func (posFeeEstimator *PoSFeeEstimator) cleanUpPastBlocks(blocks []*MsgDeSoBlock) ([]*MsgDeSoBlock, []*MsgDeSoBlock) {
dedupedBlocks, dupeBlocksRemoved := posFeeEstimator.dedupeBlocksByBlockHeight(blocks)
sortedBlocks := posFeeEstimator.sortBlocksByBlockHeight(dedupedBlocks)
return posFeeEstimator.pruneBlocksToMaxNumPastBlocks(sortedBlocks)
cleanedUpPastBlocks, prunedBlocks := posFeeEstimator.pruneBlocksToMaxNumPastBlocks(sortedBlocks)
return cleanedUpPastBlocks, append(dupeBlocksRemoved, prunedBlocks...)
}

// dedupeBlocksByBlockHeight deduplicates the blocks by block height. If multiple blocks have the same
// height, it keeps the one with the highest view.
func (posFeeEstimator *PoSFeeEstimator) dedupeBlocksByBlockHeight(blocks []*MsgDeSoBlock) []*MsgDeSoBlock {
func (posFeeEstimator *PoSFeeEstimator) dedupeBlocksByBlockHeight(blocks []*MsgDeSoBlock) (
_dedupedBlocks []*MsgDeSoBlock, _prunedBlocks []*MsgDeSoBlock) {
blocksByBlockHeight := make(map[uint64]*MsgDeSoBlock)
removedBlocks := make([]*MsgDeSoBlock, 0)
for _, block := range blocks {
existingBlock, hasExistingBlock := blocksByBlockHeight[block.Header.Height]
if !hasExistingBlock || existingBlock.Header.GetView() < block.Header.GetView() {
blocksByBlockHeight[block.Header.Height] = block
} else {
removedBlocks = append(removedBlocks, block)
}
}
return collections.MapValues(blocksByBlockHeight)
return collections.MapValues(blocksByBlockHeight), removedBlocks
}

// sortBlocksByBlockHeightAndTstamp sorts the blocks by height.
@@ -268,15 +300,17 @@ func (posFeeEstimator *PoSFeeEstimator) sortBlocksByBlockHeight(blocks []*MsgDeS
}

// pruneBlocksToMaxNumPastBlocks reduces the number of blocks to the numPastBlocks param
func (posFeeEstimator *PoSFeeEstimator) pruneBlocksToMaxNumPastBlocks(blocks []*MsgDeSoBlock) []*MsgDeSoBlock {
func (posFeeEstimator *PoSFeeEstimator) pruneBlocksToMaxNumPastBlocks(blocks []*MsgDeSoBlock) (
_cachedBlocks []*MsgDeSoBlock, _prunedBlocks []*MsgDeSoBlock) {
numCachedBlocks := uint64(len(blocks))
if numCachedBlocks <= posFeeEstimator.numPastBlocks {
return blocks
return blocks, nil
}

// Prune the blocks with the lowest block heights. We do this by removing the
// first len(blocks) - numPastBlocks blocks from the blocks slice.
return blocks[numCachedBlocks-posFeeEstimator.numPastBlocks:]
return blocks[numCachedBlocks-posFeeEstimator.numPastBlocks:],
blocks[:numCachedBlocks-posFeeEstimator.numPastBlocks]
}

// EstimateFeeRateNanosPerKB estimates the fee rate in nanos per KB for the current mempool