From 86673984820aa9dfc8e64984608a52e846ba369f Mon Sep 17 00:00:00 2001 From: Runchao Han Date: Fri, 8 Jul 2022 23:30:32 +1000 Subject: [PATCH] epoching: event and hook upon a certain threshold amount of slashed voting power (#35) --- app/app.go | 2 +- x/epoching/abci.go | 4 + x/epoching/genesis.go | 8 +- x/epoching/keeper/epoch_msg_queue.go | 139 +++++++++++ x/epoching/keeper/epoch_slashed_val_set.go | 128 ++++++++++ x/epoching/keeper/epoch_val_set.go | 154 ++++++++++++ x/epoching/keeper/epochs.go | 74 ++++++ x/epoching/keeper/grpc_query_test.go | 264 ++++++++++++++++++++- x/epoching/keeper/hooks.go | 82 ++++++- x/epoching/keeper/keeper.go | 170 ------------- x/epoching/keeper/keeper_test.go | 80 +++++++ x/epoching/keeper/msg_server.go | 49 +++- x/epoching/keeper/msg_server_test.go | 112 ++++++++- x/epoching/types/errors.go | 14 +- x/epoching/types/events.go | 32 ++- x/epoching/types/expected_keepers.go | 16 +- x/epoching/types/genesis_test.go | 15 +- x/epoching/types/hooks.go | 6 + x/epoching/types/keys.go | 10 +- x/epoching/types/msg.go | 27 ++- x/epoching/types/msg_test.go | 166 +++++++++++++ 21 files changed, 1314 insertions(+), 238 deletions(-) create mode 100644 x/epoching/keeper/epoch_msg_queue.go create mode 100644 x/epoching/keeper/epoch_slashed_val_set.go create mode 100644 x/epoching/keeper/epoch_val_set.go create mode 100644 x/epoching/keeper/epochs.go create mode 100644 x/epoching/types/msg_test.go diff --git a/app/app.go b/app/app.go index 1de42854e..8ed6825fd 100644 --- a/app/app.go +++ b/app/app.go @@ -296,7 +296,7 @@ func NewBabylonApp( // register the staking hooks // NOTE: stakingKeeper above is passed by reference, so that it will contain these hooks app.StakingKeeper = *stakingKeeper.SetHooks( - stakingtypes.NewMultiStakingHooks(app.DistrKeeper.Hooks(), app.SlashingKeeper.Hooks()), + stakingtypes.NewMultiStakingHooks(app.DistrKeeper.Hooks(), app.SlashingKeeper.Hooks(), app.EpochingKeeper.Hooks()), ) app.AuthzKeeper = authzkeeper.NewKeeper(keys[authzkeeper.StoreKey], appCodec, app.BaseApp.MsgServiceRouter()) diff --git a/x/epoching/abci.go b/x/epoching/abci.go index a6b9251f6..afbb1cab0 100644 --- a/x/epoching/abci.go +++ b/x/epoching/abci.go @@ -27,6 +27,10 @@ func BeginBlocker(ctx sdk.Context, k keeper.Keeper, req abci.RequestBeginBlock) if uint64(ctx.BlockHeight())-1 == epochBoundary.Uint64() { // increase epoch number incEpochNumber := k.IncEpochNumber(ctx) + // init the slashed voting power of this new epoch + k.InitSlashedVotingPower(ctx) + // store the current validator set + k.InitValidatorSet(ctx) // trigger AfterEpochBegins hook k.AfterEpochBegins(ctx, incEpochNumber) // emit BeginEpoch event diff --git a/x/epoching/genesis.go b/x/epoching/genesis.go index f17e61dd2..a7269fe02 100644 --- a/x/epoching/genesis.go +++ b/x/epoching/genesis.go @@ -12,9 +12,13 @@ func InitGenesis(ctx sdk.Context, k keeper.Keeper, genState types.GenesisState) // set params for this module k.SetParams(ctx, genState.Params) // init epoch number - k.SetEpochNumber(ctx, sdk.NewUint(0)) + k.InitEpochNumber(ctx) // init msg queue length - k.SetQueueLength(ctx, sdk.NewUint(0)) + k.InitQueueLength(ctx) + // init validator set + k.InitValidatorSet(ctx) + // init slashed voting power + k.InitSlashedVotingPower(ctx) } // ExportGenesis returns the capability module's exported genesis. diff --git a/x/epoching/keeper/epoch_msg_queue.go b/x/epoching/keeper/epoch_msg_queue.go new file mode 100644 index 000000000..ac1a6eb49 --- /dev/null +++ b/x/epoching/keeper/epoch_msg_queue.go @@ -0,0 +1,139 @@ +package keeper + +import ( + "github.com/babylonchain/babylon/x/epoching/types" + "github.com/cosmos/cosmos-sdk/store/prefix" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" +) + +// InitQueueLength initialises the msg queue length to 0 +func (k Keeper) InitQueueLength(ctx sdk.Context) { + store := ctx.KVStore(k.storeKey) + + queueLenBytes, err := sdk.NewUint(0).Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + + store.Set(types.QueueLengthKey, queueLenBytes) +} + +// GetQueueLength fetches the number of queued messages +func (k Keeper) GetQueueLength(ctx sdk.Context) sdk.Uint { + store := ctx.KVStore(k.storeKey) + + // get queue len in bytes from DB + bz := store.Get(types.QueueLengthKey) + if bz == nil { + panic(types.ErrUnknownQueueLen) + } + // unmarshal + var queueLen sdk.Uint + if err := queueLen.Unmarshal(bz); err != nil { + panic(sdkerrors.Wrap(types.ErrUnmarshal, err.Error())) + } + + return queueLen +} + +// setQueueLength sets the msg queue length +func (k Keeper) setQueueLength(ctx sdk.Context, queueLen sdk.Uint) { + store := ctx.KVStore(k.storeKey) + + queueLenBytes, err := queueLen.Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + + store.Set(types.QueueLengthKey, queueLenBytes) +} + +// incQueueLength adds the queue length by 1 +func (k Keeper) incQueueLength(ctx sdk.Context) { + queueLen := k.GetQueueLength(ctx) + incrementedQueueLen := queueLen.AddUint64(1) + k.setQueueLength(ctx, incrementedQueueLen) +} + +// EnqueueMsg enqueues a message to the queue of the current epoch +func (k Keeper) EnqueueMsg(ctx sdk.Context, msg types.QueuedMessage) { + // prefix: QueuedMsgKey + store := ctx.KVStore(k.storeKey) + queuedMsgStore := prefix.NewStore(store, types.QueuedMsgKey) + + // key: queueLenBytes + queueLen := k.GetQueueLength(ctx) + queueLenBytes, err := queueLen.Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + // value: msgBytes + msgBytes, err := k.cdc.Marshal(&msg) + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + queuedMsgStore.Set(queueLenBytes, msgBytes) + + // increment queue length + k.incQueueLength(ctx) +} + +// GetEpochMsgs returns the set of messages queued in the current epoch +func (k Keeper) GetEpochMsgs(ctx sdk.Context) []*types.QueuedMessage { + queuedMsgs := []*types.QueuedMessage{} + store := ctx.KVStore(k.storeKey) + + // add each queued msg to queuedMsgs + iterator := sdk.KVStorePrefixIterator(store, types.QueuedMsgKey) + defer iterator.Close() + for ; iterator.Valid(); iterator.Next() { + queuedMsgBytes := iterator.Value() + var queuedMsg types.QueuedMessage + if err := k.cdc.Unmarshal(queuedMsgBytes, &queuedMsg); err != nil { + panic(sdkerrors.Wrap(types.ErrUnmarshal, err.Error())) + } + queuedMsgs = append(queuedMsgs, &queuedMsg) + } + + return queuedMsgs +} + +// ClearEpochMsgs removes all messages in the queue +func (k Keeper) ClearEpochMsgs(ctx sdk.Context) { + store := ctx.KVStore(k.storeKey) + + // remove all epoch msgs + iterator := sdk.KVStorePrefixIterator(store, types.QueuedMsgKey) + defer iterator.Close() + for ; iterator.Valid(); iterator.Next() { + key := iterator.Key() + store.Delete(key) + } + + // set queue len to zero + k.setQueueLength(ctx, sdk.NewUint(0)) +} + +// HandleQueuedMsg unwraps a QueuedMessage and forwards it to the staking module +func (k Keeper) HandleQueuedMsg(ctx sdk.Context, msg *types.QueuedMessage) (*sdk.Result, error) { + var unwrappedMsgWithType sdk.Msg + // TODO: after we bump to Cosmos SDK v0.46, add MsgCancelUnbondingDelegation + switch unwrappedMsg := msg.Msg.(type) { + case *types.QueuedMessage_MsgCreateValidator: + unwrappedMsgWithType = unwrappedMsg.MsgCreateValidator + case *types.QueuedMessage_MsgDelegate: + unwrappedMsgWithType = unwrappedMsg.MsgDelegate + case *types.QueuedMessage_MsgUndelegate: + unwrappedMsgWithType = unwrappedMsg.MsgUndelegate + case *types.QueuedMessage_MsgBeginRedelegate: + unwrappedMsgWithType = unwrappedMsg.MsgBeginRedelegate + default: + panic(sdkerrors.Wrap(types.ErrInvalidQueuedMessageType, msg.String())) + } + + // get the handler function from router + handler := k.router.Handler(unwrappedMsgWithType) + // handle the unwrapped message + return handler(ctx, unwrappedMsgWithType) +} diff --git a/x/epoching/keeper/epoch_slashed_val_set.go b/x/epoching/keeper/epoch_slashed_val_set.go new file mode 100644 index 000000000..1a209bf57 --- /dev/null +++ b/x/epoching/keeper/epoch_slashed_val_set.go @@ -0,0 +1,128 @@ +package keeper + +import ( + "github.com/babylonchain/babylon/x/epoching/types" + "github.com/cosmos/cosmos-sdk/store/prefix" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" +) + +// setSlashedVotingPower sets the total amount of voting power that has been slashed in the epoch +func (k Keeper) setSlashedVotingPower(ctx sdk.Context, epochNumber sdk.Uint, power int64) { + store := k.slashedVotingPowerStore(ctx) + + // key: epochNumber + epochNumberBytes, err := epochNumber.Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + // value: power + powerBytes, err := sdk.NewInt(power).Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + + store.Set(epochNumberBytes, powerBytes) +} + +// InitSlashedVotingPower sets the slashed voting power of the current epoch to 0 +// This is called upon initialising the genesis state and upon a new epoch +func (k Keeper) InitSlashedVotingPower(ctx sdk.Context) { + epochNumber := k.GetEpochNumber(ctx) + k.setSlashedVotingPower(ctx, epochNumber, 0) +} + +// GetSlashedVotingPower fetches the amount of slashed voting power of a given epoch +func (k Keeper) GetSlashedVotingPower(ctx sdk.Context, epochNumber sdk.Uint) int64 { + store := k.slashedVotingPowerStore(ctx) + + // key: epochNumber + epochNumberBytes, err := epochNumber.Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + bz := store.Get(epochNumberBytes) + if bz == nil { + panic(types.ErrUnknownSlashedVotingPower) + } + // get value + var slashedVotingPower sdk.Int + if err := slashedVotingPower.Unmarshal(bz); err != nil { + panic(sdkerrors.Wrap(types.ErrUnmarshal, err.Error())) + } + + return slashedVotingPower.Int64() +} + +// AddSlashedValidator adds a slashed validator to the set of the current epoch +// This is called upon hook `BeforeValidatorSlashed` exposed by the staking module +func (k Keeper) AddSlashedValidator(ctx sdk.Context, valAddr sdk.ValAddress) { + epochNumber := k.GetEpochNumber(ctx) + store := k.slashedValSetStore(ctx, epochNumber) + + // insert into "set of slashed addresses" as KV pair, where + // - key: valAddr + // - value: empty + store.Set(valAddr, []byte{}) + + // add voting power + slashedVotingPower := k.GetSlashedVotingPower(ctx, epochNumber) + thisVotingPower := k.GetValidatorVotingPower(ctx, epochNumber, valAddr) + k.setSlashedVotingPower(ctx, epochNumber, slashedVotingPower+thisVotingPower) +} + +// GetSlashedValidators returns the set of slashed validators of a given epoch +func (k Keeper) GetSlashedValidators(ctx sdk.Context, epochNumber sdk.Uint) []sdk.ValAddress { + addrs := []sdk.ValAddress{} + store := k.slashedValSetStore(ctx, epochNumber) + // add each valAddr, which is the key + iterator := store.Iterator(nil, nil) + defer iterator.Close() + for ; iterator.Valid(); iterator.Next() { + addr := sdk.ValAddress(iterator.Key()) + addrs = append(addrs, addr) + } + + return addrs +} + +// ClearSlashedValidators removes all slashed validators in the set +// TODO: This is called upon the epoch is checkpointed +func (k Keeper) ClearSlashedValidators(ctx sdk.Context, epochNumber sdk.Uint) { + // prefix : SlashedValidatorSetKey || epochNumber + store := k.slashedValSetStore(ctx, epochNumber) + + // remove all entries with this prefix + iterator := store.Iterator(nil, nil) + defer iterator.Close() + for ; iterator.Valid(); iterator.Next() { + key := iterator.Key() + store.Delete(key) + } + + // forget the slashed voting power of this epoch + epochNumberBytes, err := epochNumber.Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + k.slashedVotingPowerStore(ctx).Delete(epochNumberBytes) +} + +// slashedValSetStore returns the KVStore of the slashed validator set for a given epoch +// prefix : SlashedValidatorSetKey || epochNumber +func (k Keeper) slashedValSetStore(ctx sdk.Context, epochNumber sdk.Uint) prefix.Store { + store := ctx.KVStore(k.storeKey) + slashedValStore := prefix.NewStore(store, types.SlashedValidatorSetKey) + epochNumberBytes, err := epochNumber.Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + return prefix.NewStore(slashedValStore, epochNumberBytes) +} + +// slashedVotingPower returns the KVStore of the slashed voting power +// prefix: SlashedVotingPowerKey +func (k Keeper) slashedVotingPowerStore(ctx sdk.Context) prefix.Store { + store := ctx.KVStore(k.storeKey) + return prefix.NewStore(store, types.SlashedVotingPowerKey) +} diff --git a/x/epoching/keeper/epoch_val_set.go b/x/epoching/keeper/epoch_val_set.go new file mode 100644 index 000000000..3791e115d --- /dev/null +++ b/x/epoching/keeper/epoch_val_set.go @@ -0,0 +1,154 @@ +package keeper + +import ( + "github.com/babylonchain/babylon/x/epoching/types" + "github.com/cosmos/cosmos-sdk/store/prefix" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" +) + +// GetValidatorSet returns the set of validators of a given epoch +func (k Keeper) GetValidatorSet(ctx sdk.Context, epochNumber sdk.Uint) map[string]int64 { + valSet := make(map[string]int64) + store := k.valSetStore(ctx, epochNumber) + iterator := store.Iterator(nil, nil) + defer iterator.Close() + for ; iterator.Valid(); iterator.Next() { + addr := string(iterator.Key()) + powerBytes := iterator.Value() + var power sdk.Int + if err := power.Unmarshal(powerBytes); err != nil { + panic(sdkerrors.Wrap(types.ErrUnmarshal, err.Error())) + } + valSet[addr] = power.Int64() + } + + return valSet +} + +// InitValidatorSet stores the validator set in the beginning of the current epoch +// This is called upon BeginBlock +func (k Keeper) InitValidatorSet(ctx sdk.Context) { + epochNumber := k.GetEpochNumber(ctx) + store := k.valSetStore(ctx, epochNumber) + totalPower := int64(0) + + // store the validator set + valSet, err := k.getValSetFromStaking(ctx) + if err != nil { + panic(err) + } + for addr, power := range valSet { + addrBytes := []byte(addr) + powerBytes, err := sdk.NewInt(power).Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + store.Set(addrBytes, powerBytes) + totalPower += power + } + // store total voting power of this validator set + epochNumberBytes, err := epochNumber.Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + totalPowerBytes, err := sdk.NewInt(totalPower).Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + k.votingPowerStore(ctx).Set(epochNumberBytes, totalPowerBytes) +} + +// ClearValidatorSet removes the validator set of a given epoch +// TODO: This is called upon the epoch is checkpointed +func (k Keeper) ClearValidatorSet(ctx sdk.Context, epochNumber sdk.Uint) { + store := k.valSetStore(ctx, epochNumber) + iterator := store.Iterator(nil, nil) + defer iterator.Close() + // clear the validator set + for ; iterator.Valid(); iterator.Next() { + key := iterator.Key() + store.Delete(key) + } + // clear total voting power of this validator set + powerStore := k.votingPowerStore(ctx) + epochNumberBytes, err := epochNumber.Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + powerStore.Delete(epochNumberBytes) +} + +// GetValidatorVotingPower returns the voting power of a given validator in a given epoch +func (k Keeper) GetValidatorVotingPower(ctx sdk.Context, epochNumber sdk.Uint, valAddr sdk.ValAddress) int64 { + store := k.valSetStore(ctx, epochNumber) + + powerBytes := store.Get(valAddr) + if powerBytes == nil { + panic(types.ErrUnknownValidator) + } + var power sdk.Int + if err := power.Unmarshal(powerBytes); err != nil { + panic(sdkerrors.Wrap(types.ErrUnmarshal, err.Error())) + } + + return power.Int64() +} + +// GetTotalVotingPower returns the total voting power of a given epoch +func (k Keeper) GetTotalVotingPower(ctx sdk.Context, epochNumber sdk.Uint) int64 { + epochNumberBytes, err := epochNumber.Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + store := k.votingPowerStore(ctx) + powerBytes := store.Get(epochNumberBytes) + if powerBytes == nil { + panic(types.ErrUnknownTotalVotingPower) + } + var power sdk.Int + if err := power.Unmarshal(powerBytes); err != nil { + panic(sdkerrors.Wrap(types.ErrUnmarshal, err.Error())) + } + return power.Int64() +} + +// valSetStore returns the KVStore of the validator set of a given epoch +// prefix: ValidatorSetKey || epochNumber +// key: string(address) +// value: voting power (in int64 as per Cosmos SDK) +func (k Keeper) valSetStore(ctx sdk.Context, epochNumber sdk.Uint) prefix.Store { + store := ctx.KVStore(k.storeKey) + valSetStore := prefix.NewStore(store, types.ValidatorSetKey) + epochNumberBytes, err := epochNumber.Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + return prefix.NewStore(valSetStore, epochNumberBytes) +} + +// votingPowerStore returns the total voting power of the validator set of a give nepoch +// prefix: ValidatorSetKey +// key: epochNumber +// value: total voting power (in int64 as per Cosmos SDK) +func (k Keeper) votingPowerStore(ctx sdk.Context) prefix.Store { + store := ctx.KVStore(k.storeKey) + return prefix.NewStore(store, types.VotingPowerKey) +} + +// get the last validator set +// key: string(address) +// value: voting power (in int64 as per Cosmos SDK) +// This is called upon BeginEpoch +// (mostly adapted from https://github.com/cosmos/cosmos-sdk/blob/v0.45.5/x/staking/keeper/val_state_change.go#L348-L373) +func (k Keeper) getValSetFromStaking(ctx sdk.Context) (map[string]int64, error) { + valSet := make(map[string]int64) + + k.stk.IterateLastValidatorPowers(ctx, func(addr sdk.ValAddress, power int64) (stop bool) { + valAddrStr := addr.String() + valSet[valAddrStr] = power + return false + }) + + return valSet, nil +} diff --git a/x/epoching/keeper/epochs.go b/x/epoching/keeper/epochs.go new file mode 100644 index 000000000..7a271fb8f --- /dev/null +++ b/x/epoching/keeper/epochs.go @@ -0,0 +1,74 @@ +package keeper + +import ( + "github.com/babylonchain/babylon/x/epoching/types" + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" +) + +const ( + DefaultEpochNumber = 0 +) + +// setEpochNumber sets epoch number +func (k Keeper) InitEpochNumber(ctx sdk.Context) { + store := ctx.KVStore(k.storeKey) + + epochNumberBytes, err := sdk.NewUint(0).Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + + store.Set(types.EpochNumberKey, epochNumberBytes) +} + +// GetEpochNumber fetches epoch number +func (k Keeper) GetEpochNumber(ctx sdk.Context) sdk.Uint { + store := ctx.KVStore(k.storeKey) + + bz := store.Get(types.EpochNumberKey) + if bz == nil { + panic(types.ErrUnknownEpochNumber) + } + var epochNumber sdk.Uint + if err := epochNumber.Unmarshal(bz); err != nil { + panic(sdkerrors.Wrap(types.ErrUnmarshal, err.Error())) + } + + return epochNumber +} + +// setEpochNumber sets epoch number +func (k Keeper) setEpochNumber(ctx sdk.Context, epochNumber sdk.Uint) { + store := ctx.KVStore(k.storeKey) + + epochNumberBytes, err := epochNumber.Marshal() + if err != nil { + panic(sdkerrors.Wrap(types.ErrMarshal, err.Error())) + } + + store.Set(types.EpochNumberKey, epochNumberBytes) +} + +// IncEpochNumber adds epoch number by 1 +func (k Keeper) IncEpochNumber(ctx sdk.Context) sdk.Uint { + epochNumber := k.GetEpochNumber(ctx) + incrementedEpochNumber := epochNumber.AddUint64(1) + k.setEpochNumber(ctx, incrementedEpochNumber) + return incrementedEpochNumber +} + +// GetEpochBoundary gets the epoch boundary, i.e., the height of the block that ends this epoch +// example: in epoch 1, epoch interval is 5 blocks, boundary will be 1*5=5 +// 0 | 1 2 3 4 5 | 6 7 8 9 10 | +// 0 | 1 | 2 | +func (k Keeper) GetEpochBoundary(ctx sdk.Context) sdk.Uint { + epochNumber := k.GetEpochNumber(ctx) + // epoch number is 0 at the 0-th block, i.e., genesis + if epochNumber.IsZero() { + return sdk.NewUint(0) + } + // case when epoch number > 0 + epochInterval := sdk.NewUint(k.GetParams(ctx).EpochInterval) + return epochNumber.Mul(epochInterval) +} diff --git a/x/epoching/keeper/grpc_query_test.go b/x/epoching/keeper/grpc_query_test.go index 8ab55dacf..03f310dfe 100644 --- a/x/epoching/keeper/grpc_query_test.go +++ b/x/epoching/keeper/grpc_query_test.go @@ -1,22 +1,266 @@ package keeper_test import ( + "fmt" + "math/rand" "testing" - testkeeper "github.com/babylonchain/babylon/testutil/keeper" - sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/babylonchain/babylon/x/epoching/types" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/query" "github.com/stretchr/testify/require" ) -func TestParamsQuery(t *testing.T) { - keeper, ctx := testkeeper.EpochingKeeper(t) +func (suite *KeeperTestSuite) TestParamsQuery() { + ctx, queryClient := suite.ctx, suite.queryClient + req := types.QueryParamsRequest{} + + testCases := []struct { + msg string + params types.Params + }{ + { + "default params", + types.DefaultParams(), + }, + } + + for _, tc := range testCases { + suite.Run(fmt.Sprintf("Case %s", tc.msg), func() { + wctx := sdk.WrapSDKContext(ctx) + resp, err := queryClient.Params(wctx, &req) + suite.NoError(err) + suite.Equal(&types.QueryParamsResponse{Params: tc.params}, resp) + }) + } +} + +// FuzzParamsQuery fuzzes queryClient.Params +// 1. Generate random param +// 2. When EpochInterval is 0, ensure `Validate` returns an error +// 3. Randomly set the param via query and check if the param has been updated +func FuzzParamsQuery(f *testing.F) { + f.Add(uint64(11111), int64(23)) + f.Add(uint64(22222), int64(330)) + f.Add(uint64(22222), int64(101)) + + f.Fuzz(func(t *testing.T, epochInterval uint64, seed int64) { + rand.Seed(seed) + + // params generated by fuzzer + params := types.DefaultParams() + params.EpochInterval = epochInterval + + // test the case of EpochInterval == 0 + // after that, change EpochInterval to a random non-zero value + if epochInterval == 0 { + params.EpochInterval = 0 + // validation should not pass with zero EpochInterval + require.Error(t, params.Validate()) + params.EpochInterval = uint64(rand.Int()) + } + + _, ctx, keeper, _, queryClient := setupTestKeeper() + wctx := sdk.WrapSDKContext(ctx) + // if setParamsFlag == 0, set params + setParamsFlag := rand.Intn(2) + if setParamsFlag == 0 { + keeper.SetParams(ctx, params) + } + req := types.QueryParamsRequest{} + resp, err := queryClient.Params(wctx, &req) + require.NoError(t, err) + // if setParamsFlag == 0, resp.Params should be changed, otherwise default + if setParamsFlag == 0 { + require.Equal(t, params, resp.Params) + } else { + require.Equal(t, types.DefaultParams(), resp.Params) + } + }) +} + +func (suite *KeeperTestSuite) TestCurrentEpoch() { + ctx, queryClient := suite.ctx, suite.queryClient + req := types.QueryCurrentEpochRequest{} + + testCases := []struct { + msg string + malleate func() + epochNumber sdk.Uint + epochBoundary sdk.Uint + }{ + { + "epoch 0", + func() {}, + sdk.NewUint(0), + sdk.NewUint(0), + }, + { + "epoch 1", + func() { + suite.keeper.IncEpochNumber(suite.ctx) + }, + sdk.NewUint(1), + sdk.NewUint(suite.keeper.GetParams(suite.ctx).EpochInterval * 1), + }, + { + "epoch 2", + func() { + suite.keeper.IncEpochNumber(suite.ctx) + }, + sdk.NewUint(2), + sdk.NewUint(suite.keeper.GetParams(suite.ctx).EpochInterval * 2), + }, + { + "reset to epoch 0", + func() { + suite.keeper.InitEpochNumber(suite.ctx) + }, + sdk.NewUint(0), + sdk.NewUint(0), + }, + } + + for _, tc := range testCases { + suite.Run(fmt.Sprintf("Case %s", tc.msg), func() { + tc.malleate() + wctx := sdk.WrapSDKContext(ctx) + resp, err := queryClient.CurrentEpoch(wctx, &req) + suite.NoError(err) + suite.Equal(tc.epochNumber.Uint64(), resp.CurrentEpoch) + suite.Equal(tc.epochBoundary.Uint64(), resp.EpochBoundary) + }) + } +} + +// FuzzCurrentEpoch fuzzes queryClient.CurrentEpoch +// 1. generate a random number of epochs to increment +// 2. query the current epoch and boundary +// 3. compare them with the correctly calculated ones +func FuzzCurrentEpoch(f *testing.F) { + f.Add(uint64(1111)) + f.Add(uint64(2222)) + f.Add(uint64(3333)) + + f.Fuzz(func(t *testing.T, increment uint64) { + _, ctx, keeper, _, queryClient := setupTestKeeper() + wctx := sdk.WrapSDKContext(ctx) + for i := uint64(0); i < increment; i++ { + keeper.IncEpochNumber(ctx) + } + req := types.QueryCurrentEpochRequest{} + resp, err := queryClient.CurrentEpoch(wctx, &req) + require.NoError(t, err) + require.Equal(t, increment, resp.CurrentEpoch) + require.Equal(t, increment*keeper.GetParams(ctx).EpochInterval, resp.EpochBoundary) + }) +} + +func (suite *KeeperTestSuite) TestEpochMsgs() { + ctx, queryClient := suite.ctx, suite.queryClient wctx := sdk.WrapSDKContext(ctx) - params := types.DefaultParams() - keeper.SetParams(ctx, params) + req := &types.QueryEpochMsgsRequest{ + Pagination: &query.PageRequest{ + Limit: 100, + }, + } + + testCases := []struct { + msg string + malleate func() + epochMsgs []*types.QueuedMessage + }{ + { + "empty epoch msgs", + func() {}, + []*types.QueuedMessage{}, + }, + { + "newly inserted epoch msg", + func() { + msg := types.QueuedMessage{ + TxId: []byte{0x01}, + } + suite.keeper.EnqueueMsg(suite.ctx, msg) + }, + []*types.QueuedMessage{ + {TxId: []byte{0x01}}, + }, + }, + { + "newly inserted epoch msg", + func() { + msg := types.QueuedMessage{ + TxId: []byte{0x02}, + } + suite.keeper.EnqueueMsg(suite.ctx, msg) + }, + []*types.QueuedMessage{ + {TxId: []byte{0x01}}, + {TxId: []byte{0x02}}, + }, + }, + { + "cleared epoch msg", + func() { + suite.keeper.ClearEpochMsgs(suite.ctx) + }, + []*types.QueuedMessage{}, + }, + } + + for _, tc := range testCases { + suite.Run(fmt.Sprintf("Case %s", tc.msg), func() { + tc.malleate() + resp, err := queryClient.EpochMsgs(wctx, req) + suite.NoError(err) + suite.Equal(len(tc.epochMsgs), len(resp.Msgs)) + suite.Equal(uint64(len(tc.epochMsgs)), suite.keeper.GetQueueLength(suite.ctx).Uint64()) + for idx := range tc.epochMsgs { + suite.Equal(tc.epochMsgs[idx].MsgId, resp.Msgs[idx].MsgId) + suite.Equal(tc.epochMsgs[idx].TxId, resp.Msgs[idx].TxId) + } + }) + } +} + +// FuzzEpochMsgs fuzzes queryClient.EpochMsgs +// 1. randomly generate msgs and limit in pagination +// 2. check the returned msg was previously enqueued +// NOTE: Msgs in QueryEpochMsgsResponse are out-of-roder +func FuzzEpochMsgs(f *testing.F) { + f.Add(int64(12)) + f.Add(int64(44)) + f.Add(int64(422)) + f.Add(int64(4222)) + + f.Fuzz(func(t *testing.T, seed int64) { + rand.Seed(seed) + numMsgs := uint64(rand.Int() % 100) + limit := uint64(rand.Int() % 100) + + txidsMap := map[string]bool{} + _, ctx, keeper, _, queryClient := setupTestKeeper() + wctx := sdk.WrapSDKContext(ctx) + // enque a random number of msgs with random txids + for i := uint64(0); i < numMsgs; i++ { + txid := genRandomByteSlice(32) + txidsMap[string(txid)] = true + keeper.EnqueueMsg(ctx, types.QueuedMessage{TxId: txid}) + } + // get epoch msgs + req := types.QueryEpochMsgsRequest{ + Pagination: &query.PageRequest{ + Limit: limit, + }, + } + resp, err := queryClient.EpochMsgs(wctx, &req) + require.NoError(t, err) - response, err := keeper.Params(wctx, &types.QueryParamsRequest{}) - require.NoError(t, err) - require.Equal(t, &types.QueryParamsResponse{Params: params}, response) + require.Equal(t, min(uint64(len(txidsMap)), limit), uint64(len(resp.Msgs))) + for idx := range resp.Msgs { + _, ok := txidsMap[string(resp.Msgs[idx].TxId)] + require.True(t, ok) + } + }) } diff --git a/x/epoching/keeper/hooks.go b/x/epoching/keeper/hooks.go index b6de4e77b..5e6958326 100644 --- a/x/epoching/keeper/hooks.go +++ b/x/epoching/keeper/hooks.go @@ -1,13 +1,25 @@ package keeper import ( + "fmt" + "github.com/babylonchain/babylon/x/epoching/types" sdk "github.com/cosmos/cosmos-sdk/types" + stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" ) -// Implements EpochingHooks interface +// Wrapper struct +type Hooks struct { + k Keeper +} + +// Implements StakingHooks/EpochingHooks interfaces +var _ stakingtypes.StakingHooks = Hooks{} var _ types.EpochingHooks = Keeper{} +// Create new distribution hooks +func (k Keeper) Hooks() Hooks { return Hooks{k} } + // AfterEpochBegins - call hook if registered func (k Keeper) AfterEpochBegins(ctx sdk.Context, epoch sdk.Uint) { if k.hooks != nil { @@ -21,3 +33,71 @@ func (k Keeper) AfterEpochEnds(ctx sdk.Context, epoch sdk.Uint) { k.hooks.AfterEpochEnds(ctx, epoch) } } + +// BeforeSlashThreshold triggers the BeforeSlashThreshold hook for other modules that register this hook +func (k Keeper) BeforeSlashThreshold(ctx sdk.Context, valAddrs []sdk.ValAddress) { + if k.hooks != nil { + k.hooks.BeforeSlashThreshold(ctx, valAddrs) + } +} + +// BeforeValidatorSlashed records the slash event +func (h Hooks) BeforeValidatorSlashed(ctx sdk.Context, valAddr sdk.ValAddress, fraction sdk.Dec) { + thresholds := []float64{float64(1) / float64(3), float64(2) / float64(3)} + + epochNumber := h.k.GetEpochNumber(ctx) + totalVotingPower := h.k.GetTotalVotingPower(ctx, epochNumber) + validatorSet := h.k.GetValidatorSet(ctx, epochNumber) + + // calculate total slashed voting power + slashedVotingPower := h.k.GetSlashedVotingPower(ctx, epochNumber) + // voting power of this validator + thisVotingPower, ok := validatorSet[valAddr.String()] + if !ok { + // It's possible that the most powerful validator outside the validator set enrols to the validator after an existing validator is slashed. + // Consequently, here we cannot find this validator in the validatorSet map. + // As we consider the validator set in the epoch beginning to be the validator set throughout this epoch, we consider this new validator in the edge to have no voting power and return directly here. + return + } + + for _, threshold := range thresholds { + // if a certain threshold voting power is slashed in a single epoch, emit event and trigger hook + if float64(slashedVotingPower) < float64(totalVotingPower)*threshold && float64(totalVotingPower)*threshold <= float64(slashedVotingPower+thisVotingPower) { + // get slashed validators + slashedVals := h.k.GetSlashedValidators(ctx, epochNumber) + slashedVals = append(slashedVals, valAddr) + // emit event + ctx.EventManager().EmitEvents(sdk.Events{ + sdk.NewEvent( + types.EventTypeSlashThreshold, + sdk.NewAttribute(types.AttributeKeySlashedVotingPower, fmt.Sprintf("%d", slashedVotingPower)), + sdk.NewAttribute(types.AttributeKeyTotalVotingPower, fmt.Sprintf("%d", slashedVotingPower)), + sdk.NewAttribute(types.AttributeKeySlashedValidators, fmt.Sprintf("%v", slashedVals)), + ), + }) + // trigger hook + h.k.BeforeSlashThreshold(ctx, slashedVals) + } + } + + // add the validator address to the set + h.k.AddSlashedValidator(ctx, valAddr) +} + +// Other staking hooks that are not used in the epoching module +func (h Hooks) AfterValidatorCreated(ctx sdk.Context, valAddr sdk.ValAddress) {} +func (h Hooks) BeforeValidatorModified(ctx sdk.Context, valAddr sdk.ValAddress) {} +func (h Hooks) AfterValidatorRemoved(ctx sdk.Context, consAddr sdk.ConsAddress, valAddr sdk.ValAddress) { +} +func (h Hooks) AfterValidatorBonded(ctx sdk.Context, consAddr sdk.ConsAddress, valAddr sdk.ValAddress) { +} +func (h Hooks) AfterValidatorBeginUnbonding(ctx sdk.Context, consAddr sdk.ConsAddress, valAddr sdk.ValAddress) { +} +func (h Hooks) BeforeDelegationCreated(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) { +} +func (h Hooks) BeforeDelegationSharesModified(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) { +} +func (h Hooks) BeforeDelegationRemoved(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) { +} +func (h Hooks) AfterDelegationModified(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) { +} diff --git a/x/epoching/keeper/keeper.go b/x/epoching/keeper/keeper.go index 3acd2644a..0957db756 100644 --- a/x/epoching/keeper/keeper.go +++ b/x/epoching/keeper/keeper.go @@ -7,7 +7,6 @@ import ( "github.com/cosmos/cosmos-sdk/baseapp" "github.com/cosmos/cosmos-sdk/codec" sdk "github.com/cosmos/cosmos-sdk/types" - sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" paramtypes "github.com/cosmos/cosmos-sdk/x/params/types" "github.com/tendermint/tendermint/libs/log" ) @@ -66,172 +65,3 @@ func (k *Keeper) SetMsgServiceRouter(router *baseapp.MsgServiceRouter) *Keeper { k.router = router return k } - -// GetEpochNumber fetches epoch number -func (k Keeper) GetEpochNumber(ctx sdk.Context) sdk.Uint { - store := ctx.KVStore(k.storeKey) - - bz := store.Get(types.EpochNumberKey) - if bz == nil { - panic(types.ErrUnknownEpochNumber) - } - var epochNumber sdk.Uint - if err := epochNumber.Unmarshal(bz); err != nil { - panic(err) - } - - return epochNumber -} - -// SetEpochNumber sets epoch number -func (k Keeper) SetEpochNumber(ctx sdk.Context, epochNumber sdk.Uint) { - store := ctx.KVStore(k.storeKey) - - epochNumberBytes, err := epochNumber.Marshal() - if err != nil { - panic(err) - } - - store.Set(types.EpochNumberKey, epochNumberBytes) -} - -// IncEpochNumber adds epoch number by 1 -func (k Keeper) IncEpochNumber(ctx sdk.Context) sdk.Uint { - epochNumber := k.GetEpochNumber(ctx) - incrementedEpochNumber := epochNumber.AddUint64(1) - k.SetEpochNumber(ctx, incrementedEpochNumber) - return incrementedEpochNumber -} - -// GetEpochBoundary gets the epoch boundary, i.e., the height of the block that ends this epoch -// example: in epoch 1, epoch interval is 5 blocks, boundary will be 1*5=5 -// 0 | 1 2 3 4 5 | 6 7 8 9 10 | -// 0 | 1 | 2 | -func (k Keeper) GetEpochBoundary(ctx sdk.Context) sdk.Uint { - epochNumber := k.GetEpochNumber(ctx) - // epoch number is 0 at the 0-th block, i.e., genesis - if epochNumber.IsZero() { - return sdk.NewUint(0) - } - // case when epoch number > 0 - epochInterval := sdk.NewUint(k.GetParams(ctx).EpochInterval) - return epochNumber.Mul(epochInterval) -} - -// GetQueueLength fetches the number of queued messages -func (k Keeper) GetQueueLength(ctx sdk.Context) sdk.Uint { - store := ctx.KVStore(k.storeKey) - - // get queue len in bytes from DB - bz := store.Get(types.QueueLengthKey) - if bz == nil { - panic(types.ErrUnknownQueueLen) - } - // unmarshal - var queueLen sdk.Uint - if err := queueLen.Unmarshal(bz); err != nil { - panic(err) - } - - return queueLen -} - -// SetQueueLength sets the msg queue length -func (k Keeper) SetQueueLength(ctx sdk.Context, queueLen sdk.Uint) { - store := ctx.KVStore(k.storeKey) - - queueLenBytes, err := queueLen.Marshal() - if err != nil { - panic(err) - } - - store.Set(types.QueueLengthKey, queueLenBytes) -} - -// incQueueLength adds the queue length by 1 -func (k Keeper) incQueueLength(ctx sdk.Context) { - queueLen := k.GetQueueLength(ctx) - incrementedQueueLen := queueLen.AddUint64(1) - k.SetQueueLength(ctx, incrementedQueueLen) -} - -// EnqueueMsg enqueues a message to the queue of the current epoch -func (k Keeper) EnqueueMsg(ctx sdk.Context, msg types.QueuedMessage) { - store := ctx.KVStore(k.storeKey) - - // insert KV pair, where - // - key: QueuedMsgKey || queueLenBytes - // - value: msgBytes - queueLen := k.GetQueueLength(ctx) - queueLenBytes, err := queueLen.Marshal() - if err != nil { - panic(err) - } - msgBytes, err := k.cdc.Marshal(&msg) - if err != nil { - panic(err) - } - store.Set(append(types.QueuedMsgKey, queueLenBytes...), msgBytes) - - // increment queue length - k.incQueueLength(ctx) -} - -// GetEpochMsgs returns the set of messages queued in the current epoch -func (k Keeper) GetEpochMsgs(ctx sdk.Context) []*types.QueuedMessage { - queuedMsgs := []*types.QueuedMessage{} - store := ctx.KVStore(k.storeKey) - - // add each queued msg to queuedMsgs - iterator := sdk.KVStorePrefixIterator(store, types.QueuedMsgKey) - defer iterator.Close() - for ; iterator.Valid(); iterator.Next() { - queuedMsgBytes := iterator.Value() - var queuedMsg types.QueuedMessage - if err := k.cdc.Unmarshal(queuedMsgBytes, &queuedMsg); err != nil { - panic(err) - } - queuedMsgs = append(queuedMsgs, &queuedMsg) - } - - return queuedMsgs -} - -// ClearEpochMsgs removes all messages in the queue -func (k Keeper) ClearEpochMsgs(ctx sdk.Context) { - store := ctx.KVStore(k.storeKey) - - // remove all epoch msgs - iterator := sdk.KVStorePrefixIterator(ctx.KVStore(k.storeKey), types.QueuedMsgKey) - defer iterator.Close() - for ; iterator.Valid(); iterator.Next() { - key := iterator.Key() - store.Delete(key) - } - - // set queue len to zero - k.SetQueueLength(ctx, sdk.NewUint(0)) -} - -// HandleQueuedMsg unwraps a QueuedMessage and forwards it to the staking module -func (k Keeper) HandleQueuedMsg(ctx sdk.Context, msg *types.QueuedMessage) (*sdk.Result, error) { - var unwrappedMsgWithType sdk.Msg - // TODO: after we bump to Cosmos SDK v0.46, add MsgCancelUnbondingDelegation - switch unwrappedMsg := msg.Msg.(type) { - case *types.QueuedMessage_MsgCreateValidator: - unwrappedMsgWithType = unwrappedMsg.MsgCreateValidator - case *types.QueuedMessage_MsgDelegate: - unwrappedMsgWithType = unwrappedMsg.MsgDelegate - case *types.QueuedMessage_MsgUndelegate: - unwrappedMsgWithType = unwrappedMsg.MsgUndelegate - case *types.QueuedMessage_MsgBeginRedelegate: - unwrappedMsgWithType = unwrappedMsg.MsgBeginRedelegate - default: - panic(sdkerrors.Wrap(types.ErrInvalidQueuedMessageType, msg.String())) - } - - // get the handler function from router - handler := k.router.Handler(unwrappedMsgWithType) - // handle the unwrapped message - return handler(ctx, unwrappedMsgWithType) -} diff --git a/x/epoching/keeper/keeper_test.go b/x/epoching/keeper/keeper_test.go index 942926490..ead1e28b6 100644 --- a/x/epoching/keeper/keeper_test.go +++ b/x/epoching/keeper/keeper_test.go @@ -1 +1,81 @@ package keeper_test + +import ( + "crypto/rand" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/cosmos/cosmos-sdk/baseapp" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + + "github.com/babylonchain/babylon/app" + "github.com/babylonchain/babylon/x/epoching/keeper" + "github.com/babylonchain/babylon/x/epoching/types" + sdk "github.com/cosmos/cosmos-sdk/types" +) + +type KeeperTestSuite struct { + suite.Suite + + app *app.BabylonApp + ctx sdk.Context + keeper *keeper.Keeper + msgSrvr types.MsgServer + queryClient types.QueryClient +} + +// setupTestKeeper creates a new server +func setupTestKeeper() (*app.BabylonApp, sdk.Context, *keeper.Keeper, types.MsgServer, types.QueryClient) { + app := app.Setup(false) + ctx := app.BaseApp.NewContext(false, tmproto.Header{}) + + epochingKeeper := app.EpochingKeeper + querier := keeper.Querier{Keeper: epochingKeeper} + queryHelper := baseapp.NewQueryServerTestHelper(ctx, app.InterfaceRegistry()) + types.RegisterQueryServer(queryHelper, querier) + queryClient := types.NewQueryClient(queryHelper) + + msgSrvr := keeper.NewMsgServerImpl(epochingKeeper) + + return app, ctx, &epochingKeeper, msgSrvr, queryClient +} + +func (suite *KeeperTestSuite) SetupTest() { + suite.app, suite.ctx, suite.keeper, suite.msgSrvr, suite.queryClient = setupTestKeeper() +} + +func TestParams(t *testing.T) { + app := app.Setup(false) + ctx := app.BaseApp.NewContext(false, tmproto.Header{}) + + expParams := types.DefaultParams() + + //check that the empty keeper loads the default + resParams := app.EpochingKeeper.GetParams(ctx) + require.True(t, expParams.Equal(resParams)) + + //modify a params, save, and retrieve + expParams.EpochInterval = 777 + app.EpochingKeeper.SetParams(ctx, expParams) + resParams = app.EpochingKeeper.GetParams(ctx) + require.True(t, expParams.Equal(resParams)) +} + +func TestKeeperTestSuite(t *testing.T) { + suite.Run(t, new(KeeperTestSuite)) +} + +func genRandomByteSlice(length uint64) []byte { + r := make([]byte, length) + rand.Read(r) + return r +} + +func min(a, b uint64) uint64 { + if a < b { + return a + } + return b +} diff --git a/x/epoching/keeper/msg_server.go b/x/epoching/keeper/msg_server.go index a4ea4c823..5eb809f62 100644 --- a/x/epoching/keeper/msg_server.go +++ b/x/epoching/keeper/msg_server.go @@ -24,8 +24,6 @@ var _ types.MsgServer = msgServer{} // WrappedDelegate handles the MsgWrappedDelegate request func (k msgServer) WrappedDelegate(goCtx context.Context, msg *types.MsgWrappedDelegate) (*types.MsgWrappedDelegateResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - - // get msg in bytes msgBytes, err := k.cdc.Marshal(msg) if err != nil { return nil, err @@ -40,8 +38,20 @@ func (k msgServer) WrappedDelegate(goCtx context.Context, msg *types.MsgWrappedD }, } - // enqueue msg k.EnqueueMsg(ctx, queuedMsg) + ctx.EventManager().EmitEvents(sdk.Events{ + sdk.NewEvent( + types.EventTypeWrappedDelegate, + sdk.NewAttribute(types.AttributeKeyValidator, msg.Msg.ValidatorAddress), + sdk.NewAttribute(sdk.AttributeKeyAmount, msg.Msg.Amount.String()), + sdk.NewAttribute(types.AttributeKeyEpochBoundary, k.GetEpochBoundary(ctx).String()), + ), + sdk.NewEvent( + sdk.EventTypeMessage, + sdk.NewAttribute(sdk.AttributeKeyModule, types.AttributeValueCategory), + sdk.NewAttribute(sdk.AttributeKeySender, msg.Msg.DelegatorAddress), + ), + }) return &types.MsgWrappedDelegateResponse{}, nil } @@ -49,8 +59,6 @@ func (k msgServer) WrappedDelegate(goCtx context.Context, msg *types.MsgWrappedD // WrappedUndelegate handles the MsgWrappedUndelegate request func (k msgServer) WrappedUndelegate(goCtx context.Context, msg *types.MsgWrappedUndelegate) (*types.MsgWrappedUndelegateResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - - // get msg in bytes msgBytes, err := k.cdc.Marshal(msg) if err != nil { return nil, err @@ -65,8 +73,20 @@ func (k msgServer) WrappedUndelegate(goCtx context.Context, msg *types.MsgWrappe }, } - // enqueue msg k.EnqueueMsg(ctx, queuedMsg) + ctx.EventManager().EmitEvents(sdk.Events{ + sdk.NewEvent( + types.EventTypeWrappedUndelegate, + sdk.NewAttribute(types.AttributeKeyValidator, msg.Msg.ValidatorAddress), + sdk.NewAttribute(sdk.AttributeKeyAmount, msg.Msg.Amount.String()), + sdk.NewAttribute(types.AttributeKeyEpochBoundary, k.GetEpochBoundary(ctx).String()), + ), + sdk.NewEvent( + sdk.EventTypeMessage, + sdk.NewAttribute(sdk.AttributeKeyModule, types.AttributeValueCategory), + sdk.NewAttribute(sdk.AttributeKeySender, msg.Msg.DelegatorAddress), + ), + }) return &types.MsgWrappedUndelegateResponse{}, nil } @@ -74,8 +94,6 @@ func (k msgServer) WrappedUndelegate(goCtx context.Context, msg *types.MsgWrappe // WrappedBeginRedelegate handles the MsgWrappedBeginRedelegate request func (k msgServer) WrappedBeginRedelegate(goCtx context.Context, msg *types.MsgWrappedBeginRedelegate) (*types.MsgWrappedBeginRedelegateResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - - // get msg in bytes msgBytes, err := k.cdc.Marshal(msg) if err != nil { return nil, err @@ -92,6 +110,21 @@ func (k msgServer) WrappedBeginRedelegate(goCtx context.Context, msg *types.MsgW // enqueue msg k.EnqueueMsg(ctx, queuedMsg) + // emit event + ctx.EventManager().EmitEvents(sdk.Events{ + sdk.NewEvent( + types.EventTypeWrappedBeginRedelegate, + sdk.NewAttribute(types.AttributeKeySrcValidator, msg.Msg.ValidatorSrcAddress), + sdk.NewAttribute(types.AttributeKeyDstValidator, msg.Msg.ValidatorDstAddress), + sdk.NewAttribute(sdk.AttributeKeyAmount, msg.Msg.Amount.String()), + sdk.NewAttribute(types.AttributeKeyEpochBoundary, k.GetEpochBoundary(ctx).String()), + ), + sdk.NewEvent( + sdk.EventTypeMessage, + sdk.NewAttribute(sdk.AttributeKeyModule, types.AttributeValueCategory), + sdk.NewAttribute(sdk.AttributeKeySender, msg.Msg.DelegatorAddress), + ), + }) return &types.MsgWrappedBeginRedelegateResponse{}, nil } diff --git a/x/epoching/keeper/msg_server_test.go b/x/epoching/keeper/msg_server_test.go index 63e6f012c..0fa2819dd 100644 --- a/x/epoching/keeper/msg_server_test.go +++ b/x/epoching/keeper/msg_server_test.go @@ -1,16 +1,112 @@ package keeper_test import ( - "context" - "testing" - - keepertest "github.com/babylonchain/babylon/testutil/keeper" - "github.com/babylonchain/babylon/x/epoching/keeper" "github.com/babylonchain/babylon/x/epoching/types" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/query" + stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" ) -func setupMsgServer(t testing.TB) (types.MsgServer, context.Context) { - k, ctx := keepertest.EpochingKeeper(t) - return keeper.NewMsgServerImpl(*k), sdk.WrapSDKContext(ctx) +func (suite *KeeperTestSuite) TestMsgWrappedDelegate() { + testCases := []struct { + name string + req *stakingtypes.MsgDelegate + expectErr bool + }{ + { + "empty wrapped msg", + &stakingtypes.MsgDelegate{}, + false, + }, + } + for _, tc := range testCases { + wctx := sdk.WrapSDKContext(suite.ctx) + suite.Run(tc.name, func() { + wrappedMsg := types.NewMsgWrappedDelegate(tc.req) + _, err := suite.msgSrvr.WrappedDelegate(wctx, wrappedMsg) + suite.Require().NoError(err) + + resp, err := suite.queryClient.EpochMsgs(wctx, &types.QueryEpochMsgsRequest{ + Pagination: &query.PageRequest{}, + }) + suite.Require().NoError(err) + suite.Require().Equal(1, len(resp.Msgs)) + + if tc.expectErr { + suite.Require().Error(err) + } else { + suite.Require().NoError(err) + } + }) + } +} + +func (suite *KeeperTestSuite) TestMsgWrappedUndelegate() { + testCases := []struct { + name string + req *stakingtypes.MsgUndelegate + expectErr bool + }{ + { + "empty wrapped msg", + &stakingtypes.MsgUndelegate{}, + false, + }, + } + for _, tc := range testCases { + wctx := sdk.WrapSDKContext(suite.ctx) + suite.Run(tc.name, func() { + wrappedMsg := types.NewMsgWrappedUndelegate(tc.req) + _, err := suite.msgSrvr.WrappedUndelegate(wctx, wrappedMsg) + suite.Require().NoError(err) + + resp, err := suite.queryClient.EpochMsgs(wctx, &types.QueryEpochMsgsRequest{ + Pagination: &query.PageRequest{}, + }) + suite.Require().NoError(err) + suite.Require().Equal(1, len(resp.Msgs)) + + if tc.expectErr { + suite.Require().Error(err) + } else { + suite.Require().NoError(err) + } + }) + } +} + +func (suite *KeeperTestSuite) TestMsgWrappedBeginRedelegate() { + testCases := []struct { + name string + req *stakingtypes.MsgBeginRedelegate + expectErr bool + }{ + { + "empty wrapped msg", + &stakingtypes.MsgBeginRedelegate{}, + false, + }, + } + for _, tc := range testCases { + wctx := sdk.WrapSDKContext(suite.ctx) + wrappedMsg := types.NewMsgWrappedBeginRedelegate(tc.req) + + _, err := suite.msgSrvr.WrappedBeginRedelegate(wctx, wrappedMsg) + suite.Require().NoError(err) + + resp, err := suite.queryClient.EpochMsgs(wctx, &types.QueryEpochMsgsRequest{ + Pagination: &query.PageRequest{}, + }) + suite.Require().NoError(err) + suite.Require().Equal(1, len(resp.Msgs)) + + suite.Run(tc.name, func() { + _, err := suite.msgSrvr.WrappedBeginRedelegate(wctx, wrappedMsg) + if tc.expectErr { + suite.Require().Error(err) + } else { + suite.Require().NoError(err) + } + }) + } } diff --git a/x/epoching/types/errors.go b/x/epoching/types/errors.go index 872fe36aa..b2f13fabf 100644 --- a/x/epoching/types/errors.go +++ b/x/epoching/types/errors.go @@ -8,8 +8,14 @@ import ( // x/epoching module sentinel errors var ( - ErrUnwrappedMsgType = sdkerrors.Register(ModuleName, 1, "invalid message type in {MsgCreateValidator, MsgDelegate, MsgUndelegate, MsgBeginRedelegate} messages. use wrapped versions instead") - ErrInvalidQueuedMessageType = sdkerrors.Register(ModuleName, 2, "invalid message type of a QueuedMessage") - ErrUnknownEpochNumber = sdkerrors.Register(ModuleName, 3, "the epoch number is not known in DB") - ErrUnknownQueueLen = sdkerrors.Register(ModuleName, 4, "the msg queue length is not known in DB") + ErrUnwrappedMsgType = sdkerrors.Register(ModuleName, 1, "invalid message type in {MsgCreateValidator, MsgDelegate, MsgUndelegate, MsgBeginRedelegate} messages. use wrapped versions instead") + ErrInvalidQueuedMessageType = sdkerrors.Register(ModuleName, 2, "invalid message type of a QueuedMessage") + ErrUnknownEpochNumber = sdkerrors.Register(ModuleName, 3, "the epoch number is not known in DB") + ErrUnknownQueueLen = sdkerrors.Register(ModuleName, 4, "the msg queue length is not known in DB") + ErrUnknownSlashedVotingPower = sdkerrors.Register(ModuleName, 5, "the slashed voting power is not known in DB. Maybe the epoch has been checkpointed?") + ErrUnknownValidator = sdkerrors.Register(ModuleName, 6, "the slashed validator is not in the validator set.") + ErrUnknownTotalVotingPower = sdkerrors.Register(ModuleName, 7, "the total voting power is not known in DB.") + ErrMarshal = sdkerrors.Register(ModuleName, 8, "marshal error.") + ErrUnmarshal = sdkerrors.Register(ModuleName, 9, "unmarshal error.") + ErrNoWrappedMsg = sdkerrors.Register(ModuleName, 10, "the wrapped msg contains no msg inside.") ) diff --git a/x/epoching/types/events.go b/x/epoching/types/events.go index 4daad53fe..2dcd67f99 100644 --- a/x/epoching/types/events.go +++ b/x/epoching/types/events.go @@ -2,15 +2,27 @@ package types // epoching module event types const ( - EventTypeBeginEpoch = "begin_epoch" - EventTypeEndEpoch = "end_epoch" - EventTypeHandleQueuedMsg = "handle_queue_msg" - EventTypeHandleQueuedMsgFailed = "handle_queue_msg_failed" + EventTypeBeginEpoch = "begin_epoch" + EventTypeEndEpoch = "end_epoch" + EventTypeHandleQueuedMsg = "handle_queue_msg" + EventTypeHandleQueuedMsgFailed = "handle_queue_msg_failed" + EventTypeSlashThreshold = "slash_threshold" + EventTypeWrappedDelegate = "wrapped_delegate" + EventTypeWrappedUndelegate = "wrapped_undelegate" + EventTypeWrappedBeginRedelegate = "wrapped_begin_redelegate" - AttributeKeyEpoch = "epoch" - AttributeKeyOriginalEventType = "original_event_type" - AttributeKeyTxId = "tx_id" - AttributeKeyMsgId = "msg_id" - AttributeKeyErrorMsg = "error_msg" - AttributeValueCategory = ModuleName + AttributeKeyEpoch = "epoch" + AttributeKeySlashedVotingPower = "slashed_voting_power" + AttributeKeyTotalVotingPower = "total_voting_power" + AttributeKeySlashedValidators = "slashed_validators" + AttributeKeyOriginalEventType = "original_event_type" + AttributeKeyTxId = "tx_id" + AttributeKeyMsgId = "msg_id" + AttributeKeyErrorMsg = "error_msg" + AttributeKeyValidator = "validator" + AttributeKeySrcValidator = "source_validator" + AttributeKeyDstValidator = "destination_validator" + AttributeKeyEpochBoundary = "epoch_boundary" + + AttributeValueCategory = ModuleName ) diff --git a/x/epoching/types/expected_keepers.go b/x/epoching/types/expected_keepers.go index 5298ce1cc..15d3e50f2 100644 --- a/x/epoching/types/expected_keepers.go +++ b/x/epoching/types/expected_keepers.go @@ -25,22 +25,30 @@ type BankKeeper interface { // StakingKeeper defines the staking module interface contract needed by the // epoching module. type StakingKeeper interface { + GetParams(ctx sdk.Context) stakingtypes.Params UnbondAllMatureValidators(ctx sdk.Context) DequeueAllMatureUBDQueue(ctx sdk.Context, currTime time.Time) (matureUnbonds []stakingtypes.DVPair) CompleteUnbonding(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) (sdk.Coins, error) DequeueAllMatureRedelegationQueue(ctx sdk.Context, currTime time.Time) (matureRedelegations []stakingtypes.DVVTriplet) CompleteRedelegation(ctx sdk.Context, delAddr sdk.AccAddress, valSrcAddr, valDstAddr sdk.ValAddress) (sdk.Coins, error) ApplyAndReturnValidatorSetUpdates(ctx sdk.Context) (updates []abci.ValidatorUpdate, err error) + IterateLastValidatorPowers(ctx sdk.Context, handler func(operator sdk.ValAddress, power int64) (stop bool)) } // Event Hooks -// These can be utilized to communicate between a staking keeper and another +// These can be utilized to communicate between an epoching keeper and another // keeper which must take particular actions when validators/delegators change // state. The second keeper must implement this interface, which then the -// staking keeper can call. +// epoching keeper can call. // EpochingHooks event hooks for epoching validator object (noalias) type EpochingHooks interface { - AfterEpochBegins(ctx sdk.Context, epoch sdk.Uint) // Must be called after an epoch begins - AfterEpochEnds(ctx sdk.Context, epoch sdk.Uint) // Must be called after an epoch ends + AfterEpochBegins(ctx sdk.Context, epoch sdk.Uint) // Must be called after an epoch begins + AfterEpochEnds(ctx sdk.Context, epoch sdk.Uint) // Must be called after an epoch ends + BeforeSlashThreshold(ctx sdk.Context, valAddrs []sdk.ValAddress) // Must be called before a certain threshold (1/3 or 2/3) of validators are slashed in a single epoch +} + +// StakingHooks event hooks for staking validator object (noalias) +type StakingHooks interface { + BeforeValidatorSlashed(ctx sdk.Context, valAddr sdk.ValAddress, fraction sdk.Dec) // Must be called right before a validator is slashed } diff --git a/x/epoching/types/genesis_test.go b/x/epoching/types/genesis_test.go index 58e5908b6..be937754f 100644 --- a/x/epoching/types/genesis_test.go +++ b/x/epoching/types/genesis_test.go @@ -3,22 +3,27 @@ package types_test import ( "testing" - keepertest "github.com/babylonchain/babylon/testutil/keeper" + "github.com/babylonchain/babylon/app" "github.com/babylonchain/babylon/testutil/nullify" "github.com/babylonchain/babylon/x/epoching" - "github.com/babylonchain/babylon/x/epoching/types" "github.com/stretchr/testify/require" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" ) func TestGenesis(t *testing.T) { + // This test requires setting up the staking module + // Otherwise the epoching module cannot initialise the genesis validator set + app := app.Setup(false) + ctx := app.BaseApp.NewContext(false, tmproto.Header{}) + keeper := app.EpochingKeeper + genesisState := types.GenesisState{ Params: types.DefaultParams(), } - k, ctx := keepertest.EpochingKeeper(t) - epoching.InitGenesis(ctx, *k, genesisState) - got := epoching.ExportGenesis(ctx, *k) + epoching.InitGenesis(ctx, keeper, genesisState) + got := epoching.ExportGenesis(ctx, keeper) require.NotNil(t, got) nullify.Fill(&genesisState) diff --git a/x/epoching/types/hooks.go b/x/epoching/types/hooks.go index a1147b96e..af9b3cdf0 100644 --- a/x/epoching/types/hooks.go +++ b/x/epoching/types/hooks.go @@ -24,3 +24,9 @@ func (h MultiEpochingHooks) AfterEpochEnds(ctx sdk.Context, epoch sdk.Uint) { h[i].AfterEpochEnds(ctx, epoch) } } + +func (h MultiEpochingHooks) BeforeSlashThreshold(ctx sdk.Context, valAddrs []sdk.ValAddress) { + for i := range h { + h[i].BeforeSlashThreshold(ctx, valAddrs) + } +} diff --git a/x/epoching/types/keys.go b/x/epoching/types/keys.go index 15be2cce1..de67275bb 100644 --- a/x/epoching/types/keys.go +++ b/x/epoching/types/keys.go @@ -18,9 +18,13 @@ const ( ) var ( - EpochNumberKey = []byte{0x11} // key prefix for the epoch number - QueueLengthKey = []byte{0x12} // key prefix for the queue length - QueuedMsgKey = []byte{0x13} // key prefix for a queued message + EpochNumberKey = []byte{0x11} // key prefix for the epoch number + QueueLengthKey = []byte{0x12} // key prefix for the queue length + QueuedMsgKey = []byte{0x13} // key prefix for a queued message + ValidatorSetKey = []byte{0x14} // key prefix for the validator set in a single epoch + VotingPowerKey = []byte{0x15} // key prefix for the total voting power of a validator set in a single epoch + SlashedVotingPowerKey = []byte{0x16} // key prefix for the total slashed voting power in a single epoch + SlashedValidatorSetKey = []byte{0x17} // key prefix for slashed validator set ) func KeyPrefix(p string) []byte { diff --git a/x/epoching/types/msg.go b/x/epoching/types/msg.go index 2aff671c0..abc9d3144 100644 --- a/x/epoching/types/msg.go +++ b/x/epoching/types/msg.go @@ -20,12 +20,10 @@ var ( ) // NewMsgWrappedDelegate creates a new MsgWrappedDelegate instance. -func NewMsgWrappedDelegate( - msg *stakingtypes.MsgDelegate, -) (*MsgWrappedDelegate, error) { +func NewMsgWrappedDelegate(msg *stakingtypes.MsgDelegate) *MsgWrappedDelegate { return &MsgWrappedDelegate{ Msg: msg, - }, nil + } } // Route implements the sdk.Msg interface. @@ -49,16 +47,17 @@ func (msg MsgWrappedDelegate) GetSignBytes() []byte { // ValidateBasic implements the sdk.Msg interface. func (msg MsgWrappedDelegate) ValidateBasic() error { + if msg.Msg == nil { + return ErrNoWrappedMsg + } return msg.Msg.ValidateBasic() } // NewMsgWrappedUndelegate creates a new MsgWrappedUndelegate instance. -func NewMsgWrappedUndelegate( - msg *stakingtypes.MsgUndelegate, -) (*MsgWrappedUndelegate, error) { +func NewMsgWrappedUndelegate(msg *stakingtypes.MsgUndelegate) *MsgWrappedUndelegate { return &MsgWrappedUndelegate{ Msg: msg, - }, nil + } } // Route implements the sdk.Msg interface. @@ -82,16 +81,17 @@ func (msg MsgWrappedUndelegate) GetSignBytes() []byte { // ValidateBasic implements the sdk.Msg interface. func (msg MsgWrappedUndelegate) ValidateBasic() error { + if msg.Msg == nil { + return ErrNoWrappedMsg + } return msg.Msg.ValidateBasic() } // NewMsgWrappedBeginRedelegate creates a new MsgWrappedBeginRedelegate instance. -func NewMsgWrappedBeginRedelegate( - msg *stakingtypes.MsgBeginRedelegate, -) (*MsgWrappedBeginRedelegate, error) { +func NewMsgWrappedBeginRedelegate(msg *stakingtypes.MsgBeginRedelegate) *MsgWrappedBeginRedelegate { return &MsgWrappedBeginRedelegate{ Msg: msg, - }, nil + } } // Route implements the sdk.Msg interface. @@ -115,5 +115,8 @@ func (msg MsgWrappedBeginRedelegate) GetSignBytes() []byte { // ValidateBasic implements the sdk.Msg interface. func (msg MsgWrappedBeginRedelegate) ValidateBasic() error { + if msg.Msg == nil { + return ErrNoWrappedMsg + } return msg.Msg.ValidateBasic() } diff --git a/x/epoching/types/msg_test.go b/x/epoching/types/msg_test.go new file mode 100644 index 000000000..6605c8930 --- /dev/null +++ b/x/epoching/types/msg_test.go @@ -0,0 +1,166 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/babylonchain/babylon/x/epoching/types" + "github.com/cosmos/cosmos-sdk/codec" + codectypes "github.com/cosmos/cosmos-sdk/codec/types" + cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec" + "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" + cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" + sdk "github.com/cosmos/cosmos-sdk/types" + stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" +) + +// Most of the code below is adapted from https://github.com/cosmos/cosmos-sdk/blob/v0.45.5/x/staking/types/msg_test.go + +var ( + pk1 = ed25519.GenPrivKey().PubKey() + pk2 = ed25519.GenPrivKey().PubKey() + pk3 = ed25519.GenPrivKey().PubKey() + valAddr1 = sdk.ValAddress(pk1.Address()) + valAddr2 = sdk.ValAddress(pk2.Address()) + valAddr3 = sdk.ValAddress(pk3.Address()) + + emptyAddr sdk.ValAddress + + coinPos = sdk.NewInt64Coin(sdk.DefaultBondDenom, 1000) + coinZero = sdk.NewInt64Coin(sdk.DefaultBondDenom, 0) +) + +func TestMsgDecode(t *testing.T) { + registry := codectypes.NewInterfaceRegistry() + cryptocodec.RegisterInterfaces(registry) + types.RegisterInterfaces(registry) + cdc := codec.NewProtoCodec(registry) + + // pubkey serialisation/deserialisation + pk1bz, err := cdc.MarshalInterface(pk1) + require.NoError(t, err) + var pkUnmarshaled cryptotypes.PubKey + err = cdc.UnmarshalInterface(pk1bz, &pkUnmarshaled) + require.NoError(t, err) + require.True(t, pk1.Equals(pkUnmarshaled.(*ed25519.PubKey))) + + // create unwrapped msg + msgUnwrapped := stakingtypes.NewMsgDelegate(sdk.AccAddress(valAddr1), valAddr2, coinPos) + + // wrap and marshal msg + msg := types.NewMsgWrappedDelegate(msgUnwrapped) + msgSerialized, err := cdc.MarshalInterface(msg) + require.NoError(t, err) + + var msgUnmarshaled sdk.Msg + err = cdc.UnmarshalInterface(msgSerialized, &msgUnmarshaled) + require.NoError(t, err) + msg2, ok := msgUnmarshaled.(*types.MsgWrappedDelegate) + require.True(t, ok) + require.Equal(t, msg.Msg.Amount, msg2.Msg.Amount) + require.Equal(t, msg.Msg.DelegatorAddress, msg2.Msg.DelegatorAddress) + require.Equal(t, msg.Msg.ValidatorAddress, msg2.Msg.ValidatorAddress) +} + +// test ValidateBasic for MsgWrappedDelegate +func TestMsgWrappedDelegate(t *testing.T) { + tests := []struct { + name string + delegatorAddr sdk.AccAddress + validatorAddr sdk.ValAddress + bond sdk.Coin + expectPass bool + }{ + {"basic good", sdk.AccAddress(valAddr1), valAddr2, coinPos, true}, + {"no wrapped msg", nil, nil, coinPos, false}, + {"self bond", sdk.AccAddress(valAddr1), valAddr1, coinPos, true}, + {"empty delegator", sdk.AccAddress(emptyAddr), valAddr1, coinPos, false}, + {"empty validator", sdk.AccAddress(valAddr1), emptyAddr, coinPos, false}, + {"empty bond", sdk.AccAddress(valAddr1), valAddr2, coinZero, false}, + {"nil bold", sdk.AccAddress(valAddr1), valAddr2, sdk.Coin{}, false}, + } + + for _, tc := range tests { + var msg *types.MsgWrappedDelegate + if tc.delegatorAddr == nil { + msg = types.NewMsgWrappedDelegate(nil) + } else { + msgUnwrapped := stakingtypes.NewMsgDelegate(tc.delegatorAddr, tc.validatorAddr, tc.bond) + msg = types.NewMsgWrappedDelegate(msgUnwrapped) + } + if tc.expectPass { + require.NoError(t, msg.ValidateBasic(), "test: %v", tc.name) + } else { + require.Error(t, msg.ValidateBasic(), "test: %v", tc.name) + } + } +} + +// test ValidateBasic for MsgWrappedBeginRedelegate +func TestMsgWrappedBeginRedelegate(t *testing.T) { + tests := []struct { + name string + delegatorAddr sdk.AccAddress + validatorSrcAddr sdk.ValAddress + validatorDstAddr sdk.ValAddress + amount sdk.Coin + expectPass bool + }{ + {"regular", sdk.AccAddress(valAddr1), valAddr2, valAddr3, sdk.NewInt64Coin(sdk.DefaultBondDenom, 1), true}, + {"no wrapped msg", nil, nil, nil, coinPos, false}, + {"zero amount", sdk.AccAddress(valAddr1), valAddr2, valAddr3, sdk.NewInt64Coin(sdk.DefaultBondDenom, 0), false}, + {"nil amount", sdk.AccAddress(valAddr1), valAddr2, valAddr3, sdk.Coin{}, false}, + {"empty delegator", sdk.AccAddress(emptyAddr), valAddr1, valAddr3, sdk.NewInt64Coin(sdk.DefaultBondDenom, 1), false}, + {"empty source validator", sdk.AccAddress(valAddr1), emptyAddr, valAddr3, sdk.NewInt64Coin(sdk.DefaultBondDenom, 1), false}, + {"empty destination validator", sdk.AccAddress(valAddr1), valAddr2, emptyAddr, sdk.NewInt64Coin(sdk.DefaultBondDenom, 1), false}, + } + + for _, tc := range tests { + var msg *types.MsgWrappedBeginRedelegate + if tc.delegatorAddr == nil { + msg = types.NewMsgWrappedBeginRedelegate(nil) + } else { + msgUnwrapped := stakingtypes.NewMsgBeginRedelegate(tc.delegatorAddr, tc.validatorSrcAddr, tc.validatorDstAddr, tc.amount) + msg = types.NewMsgWrappedBeginRedelegate(msgUnwrapped) + } + if tc.expectPass { + require.NoError(t, msg.ValidateBasic(), "test: %v", tc.name) + } else { + require.Error(t, msg.ValidateBasic(), "test: %v", tc.name) + } + } +} + +// test ValidateBasic for MsgWrappedUndelegate +func TestMsgWrappedUndelegate(t *testing.T) { + tests := []struct { + name string + delegatorAddr sdk.AccAddress + validatorAddr sdk.ValAddress + amount sdk.Coin + expectPass bool + }{ + {"regular", sdk.AccAddress(valAddr1), valAddr2, sdk.NewInt64Coin(sdk.DefaultBondDenom, 1), true}, + {"no wrapped msg", nil, nil, coinPos, false}, + {"zero amount", sdk.AccAddress(valAddr1), valAddr2, sdk.NewInt64Coin(sdk.DefaultBondDenom, 0), false}, + {"nil amount", sdk.AccAddress(valAddr1), valAddr2, sdk.Coin{}, false}, + {"empty delegator", sdk.AccAddress(emptyAddr), valAddr1, sdk.NewInt64Coin(sdk.DefaultBondDenom, 1), false}, + {"empty validator", sdk.AccAddress(valAddr1), emptyAddr, sdk.NewInt64Coin(sdk.DefaultBondDenom, 1), false}, + } + + for _, tc := range tests { + var msg *types.MsgWrappedUndelegate + if tc.delegatorAddr == nil { + msg = types.NewMsgWrappedUndelegate(nil) + } else { + msgUnwrapped := stakingtypes.NewMsgUndelegate(tc.delegatorAddr, tc.validatorAddr, tc.amount) + msg = types.NewMsgWrappedUndelegate(msgUnwrapped) + } + if tc.expectPass { + require.NoError(t, msg.ValidateBasic(), "test: %v", tc.name) + } else { + require.Error(t, msg.ValidateBasic(), "test: %v", tc.name) + } + } +}