Skip to content

Commit

Permalink
abci: skip repeated tx sig verification in proposal processing (#840)
Browse files Browse the repository at this point in the history
* abci: verfied sigs optim

This change avoids the need for a block validator to re-verify all
transaction signatures in a proposed block if they had already
verified the signatures in a previous CheckTx.

This also includes some kwild tweaks to logging to be less spammy

* populate abci validatorAddressToPubKey in constructor, not Info
  • Loading branch information
jchappelow authored Jun 20, 2024
1 parent 5fd79a4 commit 8a9318f
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 52 deletions.
5 changes: 5 additions & 0 deletions cmd/kwild/server/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,11 @@ func buildStatesyncer(d *coreDependencies) *statesync.StateSyncer {
continue
}

if res.Header.Height == 0 {
d.log.Warnf("zero height from provider %v", p)
continue
}

// Get the trust height and trust hash from the remote server
d.cfg.ChainCfg.StateSync.TrustHeight = res.Header.Height
d.cfg.ChainCfg.StateSync.TrustHash = res.Header.Hash().String()
Expand Down
125 changes: 78 additions & 47 deletions internal/abci/abci.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ import (

abciTypes "github.com/cometbft/cometbft/abci/types"
"github.com/cometbft/cometbft/crypto/ed25519"
"github.com/cometbft/cometbft/crypto/tmhash"
cmtTypes "github.com/cometbft/cometbft/types"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -66,21 +64,30 @@ func NewAbciApp(ctx context.Context, cfg *AbciConfig, snapshotter SnapshotModule
log: log,

validatorAddressToPubKey: make(map[string][]byte),
txCache: make(map[string]bool),
verifiedTxns: make(map[chainHash]struct{}),
}
app.forks.FromMap(cfg.ForkHeights)

tx, err := db.BeginOuterTx(ctx)
if err != nil {
return nil, fmt.Errorf("failed to begin outer tx: %w", err)
}
defer tx.Rollback(ctx)

app.forks.FromMap(cfg.ForkHeights)
// Populate the validatorAddressToPubKey field.
validators, err := txRouter.GetValidators(ctx, tx)
if err != nil {
return nil, fmt.Errorf("failed to get validators: %w", err)
}
for _, val := range validators {
addr, err := pubkeyToAddr(val.PubKey)
if err != nil {
return nil, fmt.Errorf("failed to convert pubkey to address: %w", err)
}

app.validatorAddressToPubKey[addr] = val.PubKey
}

// Enable any dynamically registered payloads, encoders, etc. from
// extension-defined forks that must be enabled by the current height. In
// addition to node restart, this is where forks with genesis height (0)
// activation are enabled since the first FinalizeBlock is for height 1.
height, appHash, err := meta.GetChainState(ctx, tx)
if err != nil {
return nil, err
Expand All @@ -92,6 +99,10 @@ func NewAbciApp(ctx context.Context, cfg *AbciConfig, snapshotter SnapshotModule

app.log.Infof("Preparing ABCI application at height %v, appHash %x", height, appHash)

// Enable any dynamically registered payloads, encoders, etc. from
// extension-defined forks that must be enabled by the current height. In
// addition to node restart, this is where forks with genesis height (0)
// activation are enabled since the first FinalizeBlock is for height 1.
activeForks := app.forks.ActivatedBy(uint64(height))
slices.SortStableFunc(activeForks, forks.ForkSortFunc)
for _, fork := range activeForks {
Expand Down Expand Up @@ -195,6 +206,8 @@ func proposerAddrToString(addr []byte) string {
return strings.ToUpper(hex.EncodeToString(addr))
}

type chainHash = [32]byte

type AbciApp struct {
// db is a connection to the database
db DB
Expand Down Expand Up @@ -229,15 +242,19 @@ type AbciApp struct {

broadcastFn EventBroadcaster

// validatorAddressToPubKey is a map of validator addresses to their public keys
// validatorAddressToPubKey is a map of validator addresses to their public
// keys. It should only be accessed from consensus connection methods, which
// are not called concurrently, or the constructor.
validatorAddressToPubKey map[string][]byte

// txCache stores hashes of all the transactions currently in the mempool.
// This is used to avoid recomputing the hash for all mempool transactions
// on every TxQuery request (to mitigate Potential DDOS attack vector).
// verifiedTxns stores hashes of all the transactions currently in the
// mempool, which have passed signature verification. This is used to avoid
// recomputing the hash for all mempool transactions on every TxQuery
// request (to mitigate Potential DDOS attack vector).
// https://github.com/kwilteam/kwil-db/issues/714
txCache map[string]bool
txCacheMu sync.RWMutex
verifiedTxnsMtx sync.RWMutex
verifiedTxns map[chainHash]struct{}

// halted is set to true when the network is halted for migration.
halted atomic.Bool
}
Expand Down Expand Up @@ -294,7 +311,6 @@ var _ abciTypes.Application = &AbciApp{}
func (a *AbciApp) CheckTx(ctx context.Context, incoming *abciTypes.RequestCheckTx) (*abciTypes.ResponseCheckTx, error) {
newTx := incoming.Type == abciTypes.CheckTxType_New
logger := a.log.With(zap.Bool("recheck", !newTx))
logger.Debug("check tx")

if a.halted.Load() {
return &abciTypes.ResponseCheckTx{Code: codeInvalidTxType.Uint32(), Log: "network is halted for migration"}, nil
Expand All @@ -316,7 +332,7 @@ func (a *AbciApp) CheckTx(ctx context.Context, incoming *abciTypes.RequestCheckT
return &abciTypes.ResponseCheckTx{Code: code.Uint32(), Log: err.Error()}, nil // return error now or is it still all about code?
}

logger.Debug("",
logger.Debug("check tx",
zap.String("sender", hex.EncodeToString(tx.Sender)),
zap.String("PayloadType", tx.Body.PayloadType.String()),
zap.Uint64("nonce", tx.Body.Nonce))
Expand All @@ -340,7 +356,7 @@ func (a *AbciApp) CheckTx(ctx context.Context, incoming *abciTypes.RequestCheckT
return &abciTypes.ResponseCheckTx{Code: code.Uint32(), Log: err.Error()}, nil
}
} else {
logger.Info("Recheck", zap.String("sender", hex.EncodeToString(tx.Sender)), zap.Uint64("nonce", tx.Body.Nonce), zap.String("payloadType", tx.Body.PayloadType.String()))
logger.Debug("Recheck", zap.String("sender", hex.EncodeToString(tx.Sender)), zap.Uint64("nonce", tx.Body.Nonce), zap.String("payloadType", tx.Body.PayloadType.String()))
}

readTx, err := a.db.BeginReadTx(ctx)
Expand All @@ -364,15 +380,20 @@ func (a *AbciApp) CheckTx(ctx context.Context, incoming *abciTypes.RequestCheckT
code = codeUnknownError
logger.Warn("unexpected failure to verify transaction against local mempool state", zap.Error(err))
}
// Evicting transaction from mempool.
txHash := sha256.Sum256(incoming.Tx)
a.verifiedTxnsMtx.Lock()
delete(a.verifiedTxns, txHash)
a.verifiedTxnsMtx.Unlock()
return &abciTypes.ResponseCheckTx{Code: code.Uint32(), Log: err.Error()}, nil
}

// Cache the transaction hash
if newTx {
txHash := cmtTypes.Tx(incoming.Tx).Hash()
a.txCacheMu.Lock()
defer a.txCacheMu.Unlock()
a.txCache[string(txHash)] = true
txHash := sha256.Sum256(incoming.Tx)
a.verifiedTxnsMtx.Lock()
a.verifiedTxns[txHash] = struct{}{}
a.verifiedTxnsMtx.Unlock()
}
return &abciTypes.ResponseCheckTx{Code: code.Uint32()}, nil
}
Expand Down Expand Up @@ -459,16 +480,18 @@ func (a *AbciApp) FinalizeBlock(ctx context.Context, req *abciTypes.RequestFinal
return nil, fmt.Errorf("failed to unmarshal transaction: %w", err)
}

txHash := sha256.Sum256(tx)

txRes := a.txApp.Execute(txapp.TxContext{
Ctx: ctx,
TxID: tmhash.Sum(tx), // use cometbft TmHash to get the same hash as is indexed
TxID: txHash[:], // tmhash.Sum(tx), // use cometbft TmHash to get the same hash as is indexed
BlockContext: &blockCtx,
}, a.consensusTx, decoded)

abciRes := &abciTypes.ExecTxResult{}
if txRes.Error != nil {
abciRes.Log = txRes.Error.Error()
a.log.Warn("failed to execute transaction", zap.Error(txRes.Error))
a.log.Debug("failed to execute transaction", zap.Error(txRes.Error))
} else {
abciRes.Log = "success"
}
Expand All @@ -478,10 +501,9 @@ func (a *AbciApp) FinalizeBlock(ctx context.Context, req *abciTypes.RequestFinal
res.TxResults = append(res.TxResults, abciRes)

// Remove the transaction from the cache as it has been included in a block
hash := cmtTypes.Tx(tx).Hash()
a.txCacheMu.Lock()
delete(a.txCache, string(hash))
a.txCacheMu.Unlock()
a.verifiedTxnsMtx.Lock()
delete(a.verifiedTxns, txHash)
a.verifiedTxnsMtx.Unlock()
}

// If at activation height, submit any consensus params updates associated
Expand Down Expand Up @@ -706,20 +728,6 @@ func (a *AbciApp) Info(ctx context.Context, _ *abciTypes.RequestInfo) (*abciType
height = 0 // for ChainInfo caller, non-negative is expected for genesis
}

validators, err := a.txApp.GetValidators(ctx, readTx)
if err != nil {
return nil, fmt.Errorf("failed to get validators: %w", err)
}

for _, val := range validators {
addr, err := pubkeyToAddr(val.PubKey)
if err != nil {
return nil, fmt.Errorf("failed to convert pubkey to address: %w", err)
}

a.validatorAddressToPubKey[addr] = val.PubKey
}

a.log.Info("ABCI application is ready", zap.Int64("height", height))

return &abciTypes.ResponseInfo{
Expand Down Expand Up @@ -1199,7 +1207,8 @@ func (a *AbciApp) validateProposalTransactions(ctx context.Context, txns [][]byt
}
expectedNonce := uint64(nonce) + 1

for _, tx := range txs {
for _, txO := range txs {
tx := txO.tx
if tx.Body.Nonce != expectedNonce {
logger.Warn("nonce mismatch", zap.Uint64("txNonce", tx.Body.Nonce),
zap.Uint64("expectedNonce", expectedNonce), zap.String("nonces", fmt.Sprintf("%v", nonceList(txs))))
Expand Down Expand Up @@ -1245,8 +1254,12 @@ func (a *AbciApp) validateProposalTransactions(ctx context.Context, txns [][]byt

// This block proposal may include transactions that did not pass
// through our mempool, so we have to verify all signatures.
if err = ident.VerifyTransaction(tx); err != nil {
return fmt.Errorf("transaction signature verification failed: %w", err)
if !a.TxSigVerified(txO.hash) {
if err = ident.VerifyTransaction(tx); err != nil {
return fmt.Errorf("transaction signature verification failed: %w", err)
}
// We won't bother to insert this hash into the map since it is
// very likely that this transaction is about to be executed.
}
}
}
Expand Down Expand Up @@ -1334,10 +1347,28 @@ func (a *AbciApp) SetEventBroadcaster(fn EventBroadcaster) {
a.broadcastFn = fn
}

// TxSigVerified indicates if ABCI has verified this unconfirmed transaction's
// signature. This also returns false if the transaction is not in mempool. This
// logic is not broadly applicable, but since the tx hash is computed over the
// entire serialized transaction including the signature, the same hash implies
// the same signature.
func (a *AbciApp) TxSigVerified(txHash chainHash) bool {
a.verifiedTxnsMtx.Lock()
defer a.verifiedTxnsMtx.Unlock()
_, ok := a.verifiedTxns[txHash]
return ok
}

// TxInMempool wraps TxSigVerified for callers that require a slice to check if
// a transaction is (still) in mempool.
func (a *AbciApp) TxInMempool(txHash []byte) bool {
a.txCacheMu.Lock()
defer a.txCacheMu.Unlock()
_, ok := a.txCache[string(txHash)]
if len(txHash) != 32 {
return false
}
hash := [32]byte(txHash) // requires go 1.20
a.verifiedTxnsMtx.Lock()
defer a.verifiedTxnsMtx.Unlock()
_, ok := a.verifiedTxns[hash]
return ok
}

Expand Down
20 changes: 15 additions & 5 deletions internal/abci/utils.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,41 @@
package abci

import (
"crypto/sha256"

"github.com/kwilteam/kwil-db/common/chain"
"github.com/kwilteam/kwil-db/core/types/transactions"
"github.com/kwilteam/kwil-db/extensions/consensus"
)

type txOut struct {
tx *transactions.Transaction
hash [32]byte
}

// groupTransactions groups the transactions by sender.
func groupTxsBySender(txns [][]byte) (map[string][]*transactions.Transaction, error) {
grouped := make(map[string][]*transactions.Transaction)
func groupTxsBySender(txns [][]byte) (map[string][]*txOut, error) {
grouped := make(map[string][]*txOut)
for _, tx := range txns {
t := &transactions.Transaction{}
err := t.UnmarshalBinary(tx)
if err != nil {
return nil, err
}
key := string(t.Sender)
grouped[key] = append(grouped[key], t)
grouped[key] = append(grouped[key], &txOut{
tx: t,
hash: sha256.Sum256(tx),
})
}
return grouped, nil
}

// nonceList is for debugging
func nonceList(txns []*transactions.Transaction) []uint64 {
func nonceList(txns []*txOut) []uint64 {
nonces := make([]uint64, len(txns))
for i, tx := range txns {
nonces[i] = tx.Body.Nonce
nonces[i] = tx.tx.Body.Nonce
}
return nonces
}
Expand Down

0 comments on commit 8a9318f

Please sign in to comment.