diff --git a/app/app.go b/app/app.go index 191ad30fd..ca82809b8 100644 --- a/app/app.go +++ b/app/app.go @@ -275,6 +275,15 @@ func NewBabylonApp( stakingKeeper := stakingkeeper.NewKeeper( appCodec, keys[stakingtypes.StoreKey], app.AccountKeeper, app.BankKeeper, app.GetSubspace(stakingtypes.ModuleName), ) + + // NOTE: the epoching module has to be set before the chekpointing module, as the checkpointing module will have access to the epoching module + epochingKeeper := epochingkeeper.NewKeeper( + appCodec, keys[epochingtypes.StoreKey], keys[epochingtypes.StoreKey], app.GetSubspace(epochingtypes.ModuleName), &app.StakingKeeper, + ) + // add msgServiceRouter so that the epoching module can forward unwrapped messages to the staking module + epochingKeeper.SetMsgServiceRouter(app.BaseApp.MsgServiceRouter()) + app.EpochingKeeper = epochingKeeper + app.MintKeeper = mintkeeper.NewKeeper( appCodec, keys[minttypes.StoreKey], app.GetSubspace(minttypes.ModuleName), &stakingKeeper, app.AccountKeeper, app.BankKeeper, authtypes.FeeCollectorName, @@ -321,15 +330,6 @@ func NewBabylonApp( ), ) - // setup epoching keeper - // NOTE: the epoching module has to be set before the chekpointing module, as the checkpointing module will have access to the epoching module - epochingKeeper := epochingkeeper.NewKeeper(appCodec, keys[epochingtypes.StoreKey], keys[epochingtypes.StoreKey], app.GetSubspace(epochingtypes.ModuleName), &app.StakingKeeper) - // TODO: add modules that need to hook onto the epoching module here - epochingKeeper.SetHooks(epochingtypes.NewMultiEpochingHooks()) - // add msgServiceRouter so that the epoching module can forward unwrapped messages to the staking module - epochingKeeper.SetMsgServiceRouter(app.BaseApp.MsgServiceRouter()) - app.EpochingKeeper = epochingKeeper - btclightclientKeeper := *btclightclientkeeper.NewKeeper(appCodec, keys[btclightclienttypes.StoreKey], keys[btclightclienttypes.MemStoreKey], app.GetSubspace(btclightclienttypes.ModuleName)) btclightclientKeeper.SetHooks(btclightclienttypes.NewMultiBTCLightClientHooks()) app.BTCLightClientKeeper = btclightclientKeeper diff --git a/app/test_helpers.go b/app/test_helpers.go index a753dbfe0..0dda0863e 100644 --- a/app/test_helpers.go +++ b/app/test_helpers.go @@ -89,7 +89,7 @@ func Setup(isCheckTx bool) *BabylonApp { // that also act as delegators. For simplicity, each validator is bonded with a delegation // of one consensus engine unit (10^6) in the default token of the babylon app from first genesis // account. A Nop logger is set in BabylonApp. -func SetupWithGenesisValSet(t *testing.T, valSet *tmtypes.ValidatorSet, genAccs []authtypes.GenesisAccount, balances ...banktypes.Balance) *BabylonApp { +func SetupWithGenesisValSet(t *testing.T, valSet *tmtypes.ValidatorSet, genAccs []authtypes.GenesisAccount, balances ...banktypes.Balance) (*BabylonApp, sdk.Context) { app, genesisState := setup(true, 5) // set genesis accounts authGenesis := authtypes.NewGenesisState(authtypes.DefaultParams(), genAccs) @@ -120,8 +120,12 @@ func SetupWithGenesisValSet(t *testing.T, valSet *tmtypes.ValidatorSet, genAccs } validators = append(validators, validator) delegations = append(delegations, stakingtypes.NewDelegation(genAccs[0].GetAddress(), val.Address.Bytes(), sdk.OneDec())) - } + + // total bond amount = bond amount * number of validators + require.Equal(t, len(validators), len(delegations)) + totalBondAmt := bondAmt.MulRaw(int64(len(validators))) + // set validators and delegations stakingGenesis := stakingtypes.NewGenesisState(stakingtypes.DefaultParams(), validators, delegations) genesisState[stakingtypes.ModuleName] = app.AppCodec().MustMarshalJSON(stakingGenesis) @@ -129,13 +133,13 @@ func SetupWithGenesisValSet(t *testing.T, valSet *tmtypes.ValidatorSet, genAccs totalSupply := sdk.NewCoins() for _, b := range balances { // add genesis acc tokens and delegated tokens to total supply - totalSupply = totalSupply.Add(b.Coins.Add(sdk.NewCoin(sdk.DefaultBondDenom, bondAmt))...) + totalSupply = totalSupply.Add(b.Coins.Add(sdk.NewCoin(sdk.DefaultBondDenom, totalBondAmt))...) } // add bonded amount to bonded pool module account balances = append(balances, banktypes.Balance{ Address: authtypes.NewModuleAddress(stakingtypes.BondedPoolName).String(), - Coins: sdk.Coins{sdk.NewCoin(sdk.DefaultBondDenom, bondAmt)}, + Coins: sdk.Coins{sdk.NewCoin(sdk.DefaultBondDenom, totalBondAmt)}, }) // update total supply @@ -156,14 +160,17 @@ func SetupWithGenesisValSet(t *testing.T, valSet *tmtypes.ValidatorSet, genAccs // commit genesis changes app.Commit() - app.BeginBlock(abci.RequestBeginBlock{Header: tmproto.Header{ - Height: app.LastBlockHeight() + 1, + height := app.LastBlockHeight() + 1 + header := tmproto.Header{ + Height: height, AppHash: app.LastCommitID().Hash, ValidatorsHash: valSet.Hash(), NextValidatorsHash: valSet.Hash(), - }}) + } + app.BeginBlock(abci.RequestBeginBlock{Header: header}) + ctx := app.BaseApp.NewContext(false, tmproto.Header{}) - return app + return app, ctx } // SetupWithGenesisAccounts initializes a new BabylonApp with the provided genesis diff --git a/testutil/datagen/priv_validator.go b/testutil/datagen/priv_validator.go new file mode 100644 index 000000000..d94055b39 --- /dev/null +++ b/testutil/datagen/priv_validator.go @@ -0,0 +1,53 @@ +package datagen + +import ( + "github.com/tendermint/tendermint/crypto" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + tmtypes "github.com/tendermint/tendermint/types" + + cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec" + "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" + cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" +) + +// adapted from https://github.com/ClanNetwork/clan-network/blob/d00b1cc465/testutil/simapp/pv.go +// used for creating key pairs for genesis validators in tests + +var _ tmtypes.PrivValidator = PV{} + +// PV implements PrivValidator without any safety or persistence. +// Only use it for testing. +type PV struct { + PrivKey cryptotypes.PrivKey +} + +func NewPV() PV { + return PV{ed25519.GenPrivKey()} +} + +// GetPubKey implements PrivValidator interface +func (pv PV) GetPubKey() (crypto.PubKey, error) { + return cryptocodec.ToTmPubKeyInterface(pv.PrivKey.PubKey()) +} + +// SignVote implements PrivValidator interface +func (pv PV) SignVote(chainID string, vote *tmproto.Vote) error { + signBytes := tmtypes.VoteSignBytes(chainID, vote) + sig, err := pv.PrivKey.Sign(signBytes) + if err != nil { + return err + } + vote.Signature = sig + return nil +} + +// SignProposal implements PrivValidator interface +func (pv PV) SignProposal(chainID string, proposal *tmproto.Proposal) error { + signBytes := tmtypes.ProposalSignBytes(chainID, proposal) + sig, err := pv.PrivKey.Sign(signBytes) + if err != nil { + return err + } + proposal.Signature = sig + return nil +} diff --git a/x/epoching/abci.go b/x/epoching/abci.go index c7b2384ad..c5999d191 100644 --- a/x/epoching/abci.go +++ b/x/epoching/abci.go @@ -21,9 +21,10 @@ import ( func BeginBlocker(ctx sdk.Context, k keeper.Keeper, req abci.RequestBeginBlock) { defer telemetry.ModuleMeasureSince(types.ModuleName, time.Now(), telemetry.MetricKeyBeginBlocker) - // if this block is the first block of an epoch + // if this block is the first block of the next epoch // note that we haven't incremented the epoch number yet - if k.GetEpoch(ctx).IsFirstBlock(ctx) { + epoch := k.GetEpoch(ctx) + if epoch.IsFirstBlockOfNextEpoch(ctx) { // increase epoch number IncEpoch := k.IncEpoch(ctx) // init the slashed voting power of this new epoch diff --git a/x/epoching/keeper/epoch_msg_queue_test.go b/x/epoching/keeper/epoch_msg_queue_test.go new file mode 100644 index 000000000..9c9bfc25d --- /dev/null +++ b/x/epoching/keeper/epoch_msg_queue_test.go @@ -0,0 +1,52 @@ +package keeper_test + +import ( + "math/rand" + "testing" + + "github.com/babylonchain/babylon/x/epoching/types" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" +) + +func FuzzEpochMsgQueue(f *testing.F) { + f.Add(int64(11111)) + f.Add(int64(22222)) + f.Add(int64(55555)) + f.Add(int64(12312)) + + f.Fuzz(func(t *testing.T, seed int64) { + rand.Seed(seed) + + _, ctx, keeper, _, _ := setupTestKeeper() + // ensure that the epoch msg queue is correct at the genesis + require.Empty(t, keeper.GetEpochMsgs(ctx)) + require.Equal(t, uint64(0), keeper.GetQueueLength(ctx)) + + // Enqueue a random number of msgs + numQueuedMsgs := rand.Uint64() % 100 + for i := uint64(0); i < numQueuedMsgs; i++ { + msg := types.QueuedMessage{ + TxId: sdk.Uint64ToBigEndian(i), + MsgId: sdk.Uint64ToBigEndian(i), + } + keeper.EnqueueMsg(ctx, msg) + } + + // ensure that each msg in the queue is correct + epochMsgs := keeper.GetEpochMsgs(ctx) + for i, msg := range epochMsgs { + require.Equal(t, sdk.Uint64ToBigEndian(uint64(i)), msg.TxId) + require.Equal(t, sdk.Uint64ToBigEndian(uint64(i)), msg.MsgId) + require.Nil(t, msg.Msg) + } + + // after clearing the msg queue, ensure that the epoch msg queue is empty + keeper.ClearEpochMsgs(ctx) + require.Empty(t, keeper.GetEpochMsgs(ctx)) + require.Equal(t, uint64(0), keeper.GetQueueLength(ctx)) + }) +} + +// TODO (stateful tests): fuzz HandleQueueMsg. initialise some validators, let them submit some msgs and trigger HandleQueueMsg +// require mocking valid QueueMsgs diff --git a/x/epoching/keeper/epoch_slashed_val_set_test.go b/x/epoching/keeper/epoch_slashed_val_set_test.go new file mode 100644 index 000000000..800d6abca --- /dev/null +++ b/x/epoching/keeper/epoch_slashed_val_set_test.go @@ -0,0 +1,4 @@ +package keeper_test + +// TODO (stateful tests): slash some random validators and check if the resulting (slashed) validator sets are consistent or not +// require mocking slashing diff --git a/x/epoching/keeper/epoch_val_set_test.go b/x/epoching/keeper/epoch_val_set_test.go new file mode 100644 index 000000000..427f729c6 --- /dev/null +++ b/x/epoching/keeper/epoch_val_set_test.go @@ -0,0 +1,43 @@ +package keeper_test + +import ( + "math/rand" + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" +) + +func FuzzEpochValSet(f *testing.F) { + f.Add(int64(11111)) + f.Add(int64(22222)) + f.Add(int64(55555)) + f.Add(int64(12312)) + + f.Fuzz(func(t *testing.T, seed int64) { + rand.Seed(seed) + + app, ctx, keeper, _, _, valSet := setupTestKeeperWithValSet(t) + getValSet := keeper.GetValidatorSet(ctx, 0) + require.Equal(t, len(valSet.Validators), len(getValSet)) + for i := range getValSet { + require.Equal(t, sdk.ValAddress(valSet.Validators[i].Address), getValSet[i].Addr) + } + + // generate a random number of new blocks + numIncBlocks := rand.Uint64()%1000 + 1 + for i := uint64(0); i < numIncBlocks; i++ { + ctx = genAndApplyEmptyBlock(app, ctx) + } + + // check whether the validator set remains the same or not + getValSet2 := keeper.GetValidatorSet(ctx, keeper.GetEpoch(ctx).EpochNumber) + require.Equal(t, len(valSet.Validators), len(getValSet2)) + for i := range getValSet2 { + require.Equal(t, sdk.ValAddress(valSet.Validators[i].Address), getValSet[i].Addr) + } + }) +} + +// TODO (stateful tests): create some random validators and check if the resulting validator set is consistent or not +// require mocking Msg(Wrapped)CreateValidator diff --git a/x/epoching/keeper/epochs_test.go b/x/epoching/keeper/epochs_test.go new file mode 100644 index 000000000..40af22b17 --- /dev/null +++ b/x/epoching/keeper/epochs_test.go @@ -0,0 +1,47 @@ +package keeper_test + +import ( + "math/rand" + "testing" + + "github.com/babylonchain/babylon/x/epoching/types" + "github.com/stretchr/testify/require" +) + +func FuzzEpochs(f *testing.F) { + f.Add(int64(11111)) + f.Add(int64(22222)) + f.Add(int64(55555)) + f.Add(int64(12312)) + + f.Fuzz(func(t *testing.T, seed int64) { + rand.Seed(seed) + + app, ctx, keeper, _, _ := setupTestKeeper() + // ensure that the epoch info is correct at the genesis + epoch := keeper.GetEpoch(ctx) + require.Equal(t, epoch.EpochNumber, uint64(0)) + require.Equal(t, epoch.FirstBlockHeight, uint64(0)) + + // set a random epoch interval + epochInterval := rand.Uint64()%100 + 1 + keeper.SetParams(ctx, types.Params{ + EpochInterval: epochInterval, + }) + // increment a random number of new blocks + numIncBlocks := rand.Uint64()%1000 + 1 + for i := uint64(0); i < numIncBlocks; i++ { + ctx = genAndApplyEmptyBlock(app, ctx) + } + + // ensure that the epoch info is still correct + expectedEpochNumber := numIncBlocks / epochInterval + if numIncBlocks%epochInterval > 0 { + expectedEpochNumber += 1 + } + actualNewEpoch := keeper.GetEpoch(ctx) + require.Equal(t, expectedEpochNumber, actualNewEpoch.EpochNumber) + require.Equal(t, epochInterval, actualNewEpoch.CurrentEpochInterval) + require.Equal(t, (expectedEpochNumber-1)*epochInterval+1, actualNewEpoch.FirstBlockHeight) + }) +} diff --git a/x/epoching/keeper/grpc_query_test.go b/x/epoching/keeper/grpc_query_test.go index 3176ca604..dbcadb32e 100644 --- a/x/epoching/keeper/grpc_query_test.go +++ b/x/epoching/keeper/grpc_query_test.go @@ -1,40 +1,16 @@ package keeper_test import ( - "fmt" "math/rand" "testing" + "github.com/babylonchain/babylon/testutil/datagen" "github.com/babylonchain/babylon/x/epoching/types" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/query" "github.com/stretchr/testify/require" ) -func (suite *KeeperTestSuite) TestParamsQuery() { - ctx, queryClient := suite.ctx, suite.queryClient - req := types.QueryParamsRequest{} - - testCases := []struct { - msg string - params types.Params - }{ - { - "default params", - types.DefaultParams(), - }, - } - - for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.msg), func() { - wctx := sdk.WrapSDKContext(ctx) - resp, err := queryClient.Params(wctx, &req) - suite.NoError(err) - suite.Equal(&types.QueryParamsResponse{Params: tc.params}, resp) - }) - } -} - // FuzzParamsQuery fuzzes queryClient.Params // 1. Generate random param // 2. When EpochInterval is 0, ensure `Validate` returns an error @@ -79,60 +55,6 @@ func FuzzParamsQuery(f *testing.F) { }) } -func (suite *KeeperTestSuite) TestCurrentEpoch() { - ctx, queryClient := suite.ctx, suite.queryClient - req := types.QueryCurrentEpochRequest{} - - testCases := []struct { - msg string - malleate func() - epochNumber uint64 - epochBoundary uint64 - }{ - { - "epoch 0", - func() {}, - 0, - 0, - }, - { - "epoch 1", - func() { - suite.keeper.IncEpoch(suite.ctx) - }, - 1, - suite.keeper.GetParams(suite.ctx).EpochInterval * 1, - }, - { - "epoch 2", - func() { - suite.keeper.IncEpoch(suite.ctx) - }, - 2, - suite.keeper.GetParams(suite.ctx).EpochInterval * 2, - }, - { - "reset to epoch 0", - func() { - suite.keeper.InitEpoch(suite.ctx) - }, - 0, - 0, - }, - } - - for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.msg), func() { - tc.malleate() - wctx := sdk.WrapSDKContext(ctx) - resp, err := queryClient.CurrentEpoch(wctx, &req) - suite.NoError(err) - suite.Equal(tc.epochNumber, resp.CurrentEpoch) - suite.Equal(tc.epochBoundary, resp.EpochBoundary) - }) - } -} - // FuzzCurrentEpoch fuzzes queryClient.CurrentEpoch // 1. generate a random number of epochs to increment // 2. query the current epoch and boundary @@ -156,74 +78,6 @@ func FuzzCurrentEpoch(f *testing.F) { }) } -func (suite *KeeperTestSuite) TestEpochMsgs() { - ctx, queryClient := suite.ctx, suite.queryClient - wctx := sdk.WrapSDKContext(ctx) - req := &types.QueryEpochMsgsRequest{ - Pagination: &query.PageRequest{ - Limit: 100, - }, - } - - testCases := []struct { - msg string - malleate func() - epochMsgs []*types.QueuedMessage - }{ - { - "empty epoch msgs", - func() {}, - []*types.QueuedMessage{}, - }, - { - "newly inserted epoch msg", - func() { - msg := types.QueuedMessage{ - TxId: []byte{0x01}, - } - suite.keeper.EnqueueMsg(suite.ctx, msg) - }, - []*types.QueuedMessage{ - {TxId: []byte{0x01}}, - }, - }, - { - "newly inserted epoch msg", - func() { - msg := types.QueuedMessage{ - TxId: []byte{0x02}, - } - suite.keeper.EnqueueMsg(suite.ctx, msg) - }, - []*types.QueuedMessage{ - {TxId: []byte{0x01}}, - {TxId: []byte{0x02}}, - }, - }, - { - "cleared epoch msg", - func() { - suite.keeper.ClearEpochMsgs(suite.ctx) - }, - []*types.QueuedMessage{}, - }, - } - - for _, tc := range testCases { - suite.Run(fmt.Sprintf("Case %s", tc.msg), func() { - tc.malleate() - resp, err := queryClient.EpochMsgs(wctx, req) - suite.NoError(err) - suite.Equal(len(tc.epochMsgs), len(resp.Msgs)) - suite.Equal(uint64(len(tc.epochMsgs)), suite.keeper.GetQueueLength(suite.ctx)) - for idx := range tc.epochMsgs { - suite.Equal(tc.epochMsgs[idx].MsgId, resp.Msgs[idx].MsgId) - suite.Equal(tc.epochMsgs[idx].TxId, resp.Msgs[idx].TxId) - } - }) - } -} - // FuzzEpochMsgs fuzzes queryClient.EpochMsgs // 1. randomly generate msgs and limit in pagination // 2. check the returned msg was previously enqueued @@ -244,7 +98,7 @@ func FuzzEpochMsgs(f *testing.F) { wctx := sdk.WrapSDKContext(ctx) // enque a random number of msgs with random txids for i := uint64(0); i < numMsgs; i++ { - txid := genRandomByteSlice(32) + txid := datagen.GenRandomByteArray(32) txidsMap[string(txid)] = true keeper.EnqueueMsg(ctx, types.QueuedMessage{TxId: txid}) } diff --git a/x/epoching/keeper/hooks.go b/x/epoching/keeper/hooks.go index 63f1dd506..9bbd313bd 100644 --- a/x/epoching/keeper/hooks.go +++ b/x/epoching/keeper/hooks.go @@ -8,18 +8,9 @@ import ( stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" ) -// Wrapper struct -type Hooks struct { - k Keeper -} - -// Implements StakingHooks/EpochingHooks interfaces -var _ stakingtypes.StakingHooks = Hooks{} +// ensures Keeper implements EpochingHooks interfaces var _ types.EpochingHooks = Keeper{} -// Create new distribution hooks -func (k Keeper) Hooks() Hooks { return Hooks{k} } - // AfterEpochBegins - call hook if registered func (k Keeper) AfterEpochBegins(ctx sdk.Context, epoch uint64) { if k.hooks != nil { @@ -41,6 +32,17 @@ func (k Keeper) BeforeSlashThreshold(ctx sdk.Context, valAddrs []sdk.ValAddress) } } +// Wrapper struct +type Hooks struct { + k Keeper +} + +// ensures Hooks implements StakingHooks interfaces +var _ stakingtypes.StakingHooks = Hooks{} + +// Create new distribution hooks +func (k Keeper) Hooks() Hooks { return Hooks{k} } + // BeforeValidatorSlashed records the slash event func (h Hooks) BeforeValidatorSlashed(ctx sdk.Context, valAddr sdk.ValAddress, fraction sdk.Dec) { thresholds := []float64{float64(1) / float64(3), float64(2) / float64(3)} diff --git a/x/epoching/keeper/keeper_test.go b/x/epoching/keeper/keeper_test.go index ead1e28b6..6c64f9fff 100644 --- a/x/epoching/keeper/keeper_test.go +++ b/x/epoching/keeper/keeper_test.go @@ -1,32 +1,79 @@ package keeper_test import ( - "crypto/rand" "testing" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" - "github.com/cosmos/cosmos-sdk/baseapp" - tmproto "github.com/tendermint/tendermint/proto/tendermint/types" - "github.com/babylonchain/babylon/app" + "github.com/babylonchain/babylon/testutil/datagen" "github.com/babylonchain/babylon/x/epoching/keeper" "github.com/babylonchain/babylon/x/epoching/types" + "github.com/cosmos/cosmos-sdk/baseapp" + "github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1" sdk "github.com/cosmos/cosmos-sdk/types" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" + stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" + abci "github.com/tendermint/tendermint/abci/types" + "github.com/tendermint/tendermint/crypto/merkle" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + tmtypes "github.com/tendermint/tendermint/types" ) -type KeeperTestSuite struct { - suite.Suite +func setupTestKeeperWithValSet(t *testing.T) (*app.BabylonApp, sdk.Context, *keeper.Keeper, types.MsgServer, types.QueryClient, *tmtypes.ValidatorSet) { + // generate the validator set with 10 validators + vals := []*tmtypes.Validator{} + for i := 0; i < 10; i++ { + privVal := datagen.NewPV() + pubKey, err := privVal.GetPubKey() + require.NoError(t, err) + val := tmtypes.NewValidator(pubKey, 1) + vals = append(vals, val) + } + valSet := tmtypes.NewValidatorSet(vals) + + // generate the genesis account + senderPrivKey := secp256k1.GenPrivKey() + acc := authtypes.NewBaseAccount(senderPrivKey.PubKey().Address().Bytes(), senderPrivKey.PubKey(), 0, 0) + balance := banktypes.Balance{ + Address: acc.GetAddress().String(), + Coins: sdk.NewCoins(sdk.NewCoin(sdk.DefaultBondDenom, sdk.NewInt(100000000000000))), + } - app *app.BabylonApp - ctx sdk.Context - keeper *keeper.Keeper - msgSrvr types.MsgServer - queryClient types.QueryClient + app, ctx := app.SetupWithGenesisValSet(t, valSet, []authtypes.GenesisAccount{acc}, balance) + + epochingKeeper := app.EpochingKeeper + querier := keeper.Querier{Keeper: epochingKeeper} + queryHelper := baseapp.NewQueryServerTestHelper(ctx, app.InterfaceRegistry()) + types.RegisterQueryServer(queryHelper, querier) + queryClient := types.NewQueryClient(queryHelper) + + msgSrvr := keeper.NewMsgServerImpl(epochingKeeper) + + return app, ctx, &epochingKeeper, msgSrvr, queryClient, valSet +} + +func genAndApplyEmptyBlock(app *app.BabylonApp, ctx sdk.Context) sdk.Context { + newHeight := app.LastBlockHeight() + 1 + valSet := app.StakingKeeper.GetLastValidators(ctx) + valhash := calculateValHash(valSet) + newHeader := tmproto.Header{ + Height: newHeight, + AppHash: app.LastCommitID().Hash, + ValidatorsHash: valhash, + NextValidatorsHash: valhash, + } + + app.BeginBlock(abci.RequestBeginBlock{Header: newHeader}) + app.EndBlock(abci.RequestEndBlock{Height: newHeight}) + app.Commit() + + return ctx.WithBlockHeader(newHeader) } -// setupTestKeeper creates a new server +// setupTestKeeper creates a simulated Babylon app func setupTestKeeper() (*app.BabylonApp, sdk.Context, *keeper.Keeper, types.MsgServer, types.QueryClient) { app := app.Setup(false) ctx := app.BaseApp.NewContext(false, tmproto.Header{}) @@ -42,6 +89,16 @@ func setupTestKeeper() (*app.BabylonApp, sdk.Context, *keeper.Keeper, types.MsgS return app, ctx, &epochingKeeper, msgSrvr, queryClient } +type KeeperTestSuite struct { + suite.Suite + + app *app.BabylonApp + ctx sdk.Context + keeper *keeper.Keeper + msgSrvr types.MsgServer + queryClient types.QueryClient +} + func (suite *KeeperTestSuite) SetupTest() { suite.app, suite.ctx, suite.keeper, suite.msgSrvr, suite.queryClient = setupTestKeeper() } @@ -67,15 +124,20 @@ func TestKeeperTestSuite(t *testing.T) { suite.Run(t, new(KeeperTestSuite)) } -func genRandomByteSlice(length uint64) []byte { - r := make([]byte, length) - rand.Read(r) - return r -} - func min(a, b uint64) uint64 { if a < b { return a } return b } + +// calculate validator hash and new header +// (adapted from https://github.com/cosmos/cosmos-sdk/blob/v0.45.5/simapp/test_helpers.go#L156-L163) +func calculateValHash(valSet []stakingtypes.Validator) []byte { + bzs := make([][]byte, len(valSet)) + for i, val := range valSet { + consAddr, _ := val.GetConsAddr() + bzs[i] = consAddr + } + return merkle.HashFromByteSlices(bzs) +} diff --git a/x/epoching/keeper/msg_server.go b/x/epoching/keeper/msg_server.go index c0b9d5fae..f1e53f3b2 100644 --- a/x/epoching/keeper/msg_server.go +++ b/x/epoching/keeper/msg_server.go @@ -25,20 +25,12 @@ var _ types.MsgServer = msgServer{} // WrappedDelegate handles the MsgWrappedDelegate request func (k msgServer) WrappedDelegate(goCtx context.Context, msg *types.MsgWrappedDelegate) (*types.MsgWrappedDelegateResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - msgBytes, err := k.cdc.Marshal(msg) + txid := tmhash.Sum(ctx.TxBytes()) + queuedMsg, err := types.NewQueuedMessage(txid, msg) if err != nil { return nil, err } - // wrapped -> unwrapped -> QueuedMessage - queuedMsg := types.QueuedMessage{ - TxId: tmhash.Sum(ctx.TxBytes()), - MsgId: tmhash.Sum(msgBytes), - Msg: &types.QueuedMessage_MsgDelegate{ - MsgDelegate: msg.Msg, - }, - } - k.EnqueueMsg(ctx, queuedMsg) ctx.EventManager().EmitEvents(sdk.Events{ sdk.NewEvent( @@ -60,20 +52,12 @@ func (k msgServer) WrappedDelegate(goCtx context.Context, msg *types.MsgWrappedD // WrappedUndelegate handles the MsgWrappedUndelegate request func (k msgServer) WrappedUndelegate(goCtx context.Context, msg *types.MsgWrappedUndelegate) (*types.MsgWrappedUndelegateResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - msgBytes, err := k.cdc.Marshal(msg) + txid := tmhash.Sum(ctx.TxBytes()) + queuedMsg, err := types.NewQueuedMessage(txid, msg) if err != nil { return nil, err } - // wrapped -> unwrapped -> QueuedMessage - queuedMsg := types.QueuedMessage{ - TxId: tmhash.Sum(ctx.TxBytes()), - MsgId: tmhash.Sum(msgBytes), - Msg: &types.QueuedMessage_MsgUndelegate{ - MsgUndelegate: msg.Msg, - }, - } - k.EnqueueMsg(ctx, queuedMsg) ctx.EventManager().EmitEvents(sdk.Events{ sdk.NewEvent( @@ -95,20 +79,12 @@ func (k msgServer) WrappedUndelegate(goCtx context.Context, msg *types.MsgWrappe // WrappedBeginRedelegate handles the MsgWrappedBeginRedelegate request func (k msgServer) WrappedBeginRedelegate(goCtx context.Context, msg *types.MsgWrappedBeginRedelegate) (*types.MsgWrappedBeginRedelegateResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - msgBytes, err := k.cdc.Marshal(msg) + txid := tmhash.Sum(ctx.TxBytes()) + queuedMsg, err := types.NewQueuedMessage(txid, msg) if err != nil { return nil, err } - // wrapped -> unwrapped -> QueuedMessage - queuedMsg := types.QueuedMessage{ - TxId: tmhash.Sum(ctx.TxBytes()), - MsgId: tmhash.Sum(msgBytes), - Msg: &types.QueuedMessage_MsgBeginRedelegate{ - MsgBeginRedelegate: msg.Msg, - }, - } - // enqueue msg k.EnqueueMsg(ctx, queuedMsg) // emit event diff --git a/x/epoching/types/epoching.go b/x/epoching/types/epoching.go index f4eb5ec25..43c17c085 100644 --- a/x/epoching/types/epoching.go +++ b/x/epoching/types/epoching.go @@ -2,6 +2,7 @@ package types import ( sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/tendermint/tendermint/crypto/tmhash" ) func (e Epoch) GetLastBlockHeight() uint64 { @@ -29,3 +30,54 @@ func (e Epoch) IsFirstBlock(ctx sdk.Context) bool { func (e Epoch) IsSecondBlock(ctx sdk.Context) bool { return e.GetSecondBlockHeight() == uint64(ctx.BlockHeight()) } + +func (e Epoch) IsFirstBlockOfNextEpoch(ctx sdk.Context) bool { + if e.EpochNumber == 0 { + return ctx.BlockHeight() == 1 + } else { + height := uint64(ctx.BlockHeight()) + return e.FirstBlockHeight+e.CurrentEpochInterval == height + } +} + +// NewQueuedMessage creates a new QueuedMessage from a wrapped msg +// i.e., wrapped -> unwrapped -> QueuedMessage +func NewQueuedMessage(txid []byte, msg sdk.Msg) (QueuedMessage, error) { + // marshal the actual msg (MsgDelegate, MsgBeginRedelegate, MsgUndelegate, ...) inside isQueuedMessage_Msg + // TODO (non-urgent): after we bump to Cosmos SDK v0.46, add MsgCancelUnbondingDelegation + var qmsg isQueuedMessage_Msg + var msgBytes []byte + var err error + switch msg := msg.(type) { + case *MsgWrappedDelegate: + if msgBytes, err = msg.Msg.Marshal(); err != nil { + return QueuedMessage{}, err + } + qmsg = &QueuedMessage_MsgDelegate{ + MsgDelegate: msg.Msg, + } + case *MsgWrappedBeginRedelegate: + if msgBytes, err = msg.Msg.Marshal(); err != nil { + return QueuedMessage{}, err + } + qmsg = &QueuedMessage_MsgBeginRedelegate{ + MsgBeginRedelegate: msg.Msg, + } + case *MsgWrappedUndelegate: + if msgBytes, err = msg.Msg.Marshal(); err != nil { + return QueuedMessage{}, err + } + qmsg = &QueuedMessage_MsgUndelegate{ + MsgUndelegate: msg.Msg, + } + default: + return QueuedMessage{}, ErrUnwrappedMsgType + } + + queuedMsg := QueuedMessage{ + TxId: txid, + MsgId: tmhash.Sum(msgBytes), + Msg: qmsg, + } + return queuedMsg, nil +} diff --git a/x/epoching/types/validator.go b/x/epoching/types/validator.go index df27e299b..9ef6c518e 100644 --- a/x/epoching/types/validator.go +++ b/x/epoching/types/validator.go @@ -1,9 +1,10 @@ package types import ( + "sort" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/pkg/errors" - "sort" ) type Validator struct {