Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[txpool] Long time consuming ddos attack protection #254

Merged
merged 7 commits into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chain/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type Params struct {
Engine map[string]interface{} `json:"engine"`
BlockGasTarget uint64 `json:"blockGasTarget"`
BlackList []string `json:"blackList,omitempty"`
DDOSPretection bool `json:"ddosPretection,omitempty"`
}

func (p *Params) GetEngine() string {
Expand Down
5 changes: 2 additions & 3 deletions consensus/ibft/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ import (
)

const (
KeyType = "type"
KeyEpochSize = "epochSize"
KeyBanishAbnormalContract = "banishAbnormalContract"
KeyType = "type"
KeyEpochSize = "epochSize"
)

// Define the type of the IBFT consensus
Expand Down
218 changes: 133 additions & 85 deletions consensus/ibft/ibft.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ import (
)

const (
DefaultEpochSize = 100000
DefaultBanishAbnormalContract = false // banish abnormal contract whose execution consumes too much time.
DefaultEpochSize = 100000
// banish abnormal contract whose execution consumes too much time
DefaultBanishAbnormalContract = false
)

var (
Expand All @@ -54,7 +55,13 @@ type blockchainInterface interface {
CalculateGasLimit(number uint64) (uint64, error)
}

type ddosProtectionInterface interface {
IsDDOSTx(tx *types.Transaction) bool
MarkDDOSTx(tx *types.Transaction)
}

type txPoolInterface interface {
ddosProtectionInterface
Drop(tx *types.Transaction)
DemoteAllPromoted(tx *types.Transaction, correctNonce uint64)
ResetWithHeaders(headers ...*types.Header)
Expand Down Expand Up @@ -116,9 +123,9 @@ type Ibft struct {
// Dynamic References for signing and validating
currentTxSigner crypto.TxSigner // Tx Signer at current sequence
currentValidators validator.Validators // Validator set at current sequence
// for banishing some exhausting contracts
banishAbnormalContract bool
exhaustingContracts map[types.Address]struct{}
// Recording resource exhausting contracts
// but would not banish it until it became a real ddos attack
exhaustingContracts map[types.Address]struct{}
}

// runHook runs a specified hook if it is present in the hook map
Expand Down Expand Up @@ -170,36 +177,23 @@ func Factory(
}
}

var banishAbnormalContract bool
if definedBanish, ok := params.Config.Config[KeyBanishAbnormalContract]; !ok {
banishAbnormalContract = DefaultBanishAbnormalContract
} else {
banish, ok := definedBanish.(bool)
if !ok {
return nil, errors.New("banishAbnormalContract invalid type assertion")
}

banishAbnormalContract = banish
}

p := &Ibft{
logger: params.Logger.Named("ibft"),
config: params.Config,
Grpc: params.Grpc,
blockchain: params.Blockchain,
executor: params.Executor,
closeCh: make(chan struct{}),
isClosed: atomic.NewBool(false),
txpool: params.Txpool,
state: &currentstate.CurrentState{},
network: params.Network,
epochSize: epochSize,
sealing: params.Seal,
metrics: params.Metrics,
secretsManager: params.SecretsManager,
blockTime: time.Duration(params.BlockTime) * time.Second,
banishAbnormalContract: banishAbnormalContract,
exhaustingContracts: make(map[types.Address]struct{}),
logger: params.Logger.Named("ibft"),
config: params.Config,
Grpc: params.Grpc,
blockchain: params.Blockchain,
executor: params.Executor,
closeCh: make(chan struct{}),
isClosed: atomic.NewBool(false),
txpool: params.Txpool,
state: &currentstate.CurrentState{},
network: params.Network,
epochSize: epochSize,
sealing: params.Seal,
metrics: params.Metrics,
secretsManager: params.SecretsManager,
blockTime: time.Duration(params.BlockTime) * time.Second,
exhaustingContracts: make(map[types.Address]struct{}),
}

// Initialize the mechanism
Expand Down Expand Up @@ -714,47 +708,12 @@ func (i *Ibft) buildBlock(snap *Snapshot, parent *types.Header) (*types.Block, e

// insert system transactions at last to ensure it works
if i.shouldWriteSystemTransactions(header.Number) {
txn := transition.Txn()

// make slash tx if needed
if i.currentRound() > 0 {
// only punish the first validator
lastBlockProposer, _ := ecrecoverFromHeader(parent)

needPunished := i.state.CalcNeedPunished(i.currentRound(), lastBlockProposer)
if len(needPunished) > 0 {
tx, err := i.makeTransitionSlashTx(txn, header.Number, needPunished[0])
if err != nil {
return nil, err
}

// system transaction, increase gas limit if needed
increaseHeaderGasIfNeeded(transition, header, tx)

// execute slash tx
if err := transition.Write(tx); err != nil {
return nil, err
}

txs = append(txs, tx)
}
}

// make deposit tx
tx, err := i.makeTransitionDepositTx(transition.Txn(), header.Number)
systemTxs, err := i.writeSystemTxs(transition, parent, header)
if err != nil {
return nil, err
}

// system transaction, increase gas limit if needed
increaseHeaderGasIfNeeded(transition, header, tx)

// execute deposit tx
if err := transition.Write(tx); err != nil {
return nil, err
}

txs = append(txs, tx)
txs = append(txs, systemTxs...)
}

if err := i.PreStateCommit(header, transition); err != nil {
Expand Down Expand Up @@ -804,6 +763,84 @@ func (i *Ibft) buildBlock(snap *Snapshot, parent *types.Header) (*types.Block, e
return block, nil
}

func (i *Ibft) writeSystemSlashTx(
transition *state.Transition,
parent, header *types.Header,
) (*types.Transaction, error) {
if i.currentRound() == 0 {
// no need slashing
return nil, nil
}

// only punish the first validator
lastBlockProposer, _ := ecrecoverFromHeader(parent)

needPunished := i.state.CalcNeedPunished(i.currentRound(), lastBlockProposer)
if len(needPunished) == 0 {
// it shouldn't be, but we still need to prevent overwhelming
return nil, nil
}

tx, err := i.makeTransitionSlashTx(transition.Txn(), header.Number, needPunished[0])
if err != nil {
return nil, err
}

// system transaction, increase gas limit if needed
increaseHeaderGasIfNeeded(transition, header, tx)

// execute slash tx
if err := transition.Write(tx); err != nil {
return nil, err
}

return tx, nil
}

func (i *Ibft) writeSystemDepositTx(
transition *state.Transition,
header *types.Header,
) (*types.Transaction, error) {
// make deposit tx
tx, err := i.makeTransitionDepositTx(transition.Txn(), header.Number)
if err != nil {
return nil, err
}

// system transaction, increase gas limit if needed
increaseHeaderGasIfNeeded(transition, header, tx)

// execute deposit tx
if err := transition.Write(tx); err != nil {
return nil, err
}

return tx, nil
}

func (i *Ibft) writeSystemTxs(
transition *state.Transition,
parent, header *types.Header,
) (txs []*types.Transaction, err error) {
// slash transaction
slashTx, err := i.writeSystemSlashTx(transition, parent, header)
if err != nil {
return nil, err
} else if slashTx != nil {
txs = append(txs, slashTx)
}

// deposit transaction
depositTx, err := i.writeSystemDepositTx(transition, header)
if err != nil {
return nil, err
}

txs = append(txs, depositTx)

return txs, nil
}

func increaseHeaderGasIfNeeded(transition *state.Transition, header *types.Header, tx *types.Transaction) {
if transition.TotalGas()+tx.Gas <= header.GasLimit {
return
Expand Down Expand Up @@ -936,12 +973,18 @@ func (i *Ibft) writeTransactions(
break
}

if i.shouldBanishTx(tx) {
i.logger.Info("banish some exausting contract and drop all sender transactions",
if i.shouldMarkLongConsumingTx(tx) {
// count attack
i.countDDOSAttack(tx)
}

if i.txpool.IsDDOSTx(tx) {
i.logger.Info("drop ddos attack contract transaction",
"address", tx.To,
"from", tx.From,
)

// drop tx
shouldDropTxs = append(shouldDropTxs, tx)

continue
Expand All @@ -963,7 +1006,8 @@ func (i *Ibft) writeTransactions(
begin := time.Now() // for duration calculation

if err := transition.Write(tx); err != nil {
i.banishLongTimeConsumingTx(tx, begin)
// mark long time consuming contract to prevent ddos attack
i.markLongTimeConsumingContract(tx, begin)

i.logger.Debug("write transaction failed", "hash", tx.Hash, "from", tx.From,
"nonce", tx.Nonce, "err", err)
Expand Down Expand Up @@ -1006,7 +1050,8 @@ func (i *Ibft) writeTransactions(

// no errors, go on
priceTxs.Shift()
i.banishLongTimeConsumingTx(tx, begin)
// mark long time consuming contract to prevent ddos attack
i.markLongTimeConsumingContract(tx, begin)

includedTransactions = append(includedTransactions, tx)
}
Expand All @@ -1024,34 +1069,36 @@ func (i *Ibft) shouldTerminate(terminalTime time.Time) bool {
return time.Now().After(terminalTime)
}

func (i *Ibft) shouldBanishTx(tx *types.Transaction) bool {
if !i.banishAbnormalContract || tx.To == nil {
func (i *Ibft) shouldMarkLongConsumingTx(tx *types.Transaction) bool {
if tx.To == nil {
return false
}

// if tx send to some banish contract, drop it
_, shouldBanish := i.exhaustingContracts[*tx.To]
_, exists := i.exhaustingContracts[*tx.To]

return exists
}

return shouldBanish
func (i *Ibft) countDDOSAttack(tx *types.Transaction) {
i.txpool.MarkDDOSTx(tx)
}

func (i *Ibft) banishLongTimeConsumingTx(tx *types.Transaction, begin time.Time) {
func (i *Ibft) markLongTimeConsumingContract(tx *types.Transaction, begin time.Time) {
duration := time.Since(begin).Milliseconds()
if duration < i.blockTime.Milliseconds() ||
tx.To == nil { // long contract creation is tolerable
// long contract creation is tolerable, long time execution is not tolerable
if tx.To == nil || duration < i.blockTime.Milliseconds() {
return
}

// banish the contract
i.exhaustingContracts[*tx.To] = struct{}{}

i.logger.Info("banish contract who consumes too many CPU time",
i.logger.Info("mark contract who consumes too many CPU or I/O time",
"duration", duration,
"from", tx.From,
"to", tx.To,
"gasPrice", tx.GasPrice,
"gas", tx.Gas,
"len", len(tx.Input),
)
}

Expand Down Expand Up @@ -1555,6 +1602,7 @@ func (i *Ibft) runRoundChangeState() {
for i.getState() == currentstate.RoundChangeState {
// timeout should update every time it enters a new round
timeout := i.state.MessageTimeout()

msg, ok := i.getNextMessage(timeout)
if !ok {
// closing
Expand Down
Loading