diff --git a/CHANGELOG.md b/CHANGELOG.md index ef64803cd83a..c0c386dfc648 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,7 +40,8 @@ Ref: https://keepachangelog.com/en/1.0.0/ ### Features -* [#22795](https://github.com/cosmos/cosmos-sdk/pull/22795) `NewUnbondingDelegationEntry`, `NewUnbondingDelegation`, `AddEntry`, `NewRedelegationEntry`, `NewRedelegation` and `NewRedelegationEntryResponse` no longer take an ID in there function signatures. AfterUnbondingInitiated hook has been removed as it is no longer required by ICS. Keys `stakingtypes.UnbondingIDKey, stakingtypes.UnbondingIndexKey, stakingtypes.UnbondingTypeKey` have been removed as they are no longer required by ICS. +* (x/staking) [#1725](https://github.com/crypto-org-chain/cosmos-sdk/pull/1725) Optimize staking endblocker execution by caching queue entries from iterators. +* (x/staking) [#22795](https://github.com/cosmos/cosmos-sdk/pull/22795) `NewUnbondingDelegationEntry`, `NewUnbondingDelegation`, `AddEntry`, `NewRedelegationEntry`, `NewRedelegation` and `NewRedelegationEntryResponse` no longer take an ID in their function signatures. AfterUnbondingInitiated hook has been removed as it is no longer required by ICS. Keys `stakingtypes.UnbondingIDKey, stakingtypes.UnbondingIndexKey, stakingtypes.UnbondingTypeKey` have been removed as they are no longer required by ICS. * (baseapp) [#205](https://github.com/crypto-org-chain/cosmos-sdk/pull/205) Add `TxExecutor` baseapp option, add `TxIndex`/`TxCount`/`MsgIndex`/`BlockGasUsed` fields to `Context, to support tx parallel execution. * (baseapp) [#206](https://github.com/crypto-org-chain/cosmos-sdk/pull/206) Support mount object store in baseapp, add `ObjectStore` api in context, [#585](https://github.com/crypto-org-chain/cosmos-sdk/pull/585) Skip snapshot for object store. * (bank) [#237](https://github.com/crypto-org-chain/cosmos-sdk/pull/237) Support virtual accounts in sending coins. 
diff --git a/proto/cosmos/staking/v1beta1/query.proto b/proto/cosmos/staking/v1beta1/query.proto index 9b9812235274..9496f8aeaed2 100644 --- a/proto/cosmos/staking/v1beta1/query.proto +++ b/proto/cosmos/staking/v1beta1/query.proto @@ -384,4 +384,4 @@ message QueryParamsRequest {} message QueryParamsResponse { // params holds all the parameters of this module. Params params = 1 [(gogoproto.nullable) = false, (amino.dont_omitempty) = true]; -} +} diff --git a/server/config/config.go b/server/config/config.go index d97dedc5f54e..ba6479317854 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -168,6 +168,16 @@ type MempoolConfig struct { MaxTxs int `mapstructure:"max-txs"` } +// StakingConfig defines the staking module configuration +type StakingConfig struct { + // CacheSize defines the maximum number of time-based queue entries to cache + // for unbonding validators, unbonding delegations, and redelegations. + // cache-size = 0 means unlimited cache (no size limit). + // cache-size < 0 means the cache is disabled. + // cache-size > 0 sets a size limit for the cache. + CacheSize int `mapstructure:"cache-size"` +} + // State Streaming configuration type ( // StreamingConfig defines application configuration for external streaming services @@ -194,6 +204,7 @@ type Config struct { StateSync StateSyncConfig `mapstructure:"state-sync"` Streaming StreamingConfig `mapstructure:"streaming"` Mempool MempoolConfig `mapstructure:"mempool"` + Staking StakingConfig `mapstructure:"staking"` } // SetMinGasPrices sets the validator's minimum gas prices. 
@@ -265,6 +276,9 @@ func DefaultConfig() *Config { Mempool: MempoolConfig{ MaxTxs: -1, }, + Staking: StakingConfig{ + CacheSize: 0, + }, } } diff --git a/server/config/toml.go b/server/config/toml.go index b3c20c69a024..7959c6b229d2 100644 --- a/server/config/toml.go +++ b/server/config/toml.go @@ -246,6 +246,18 @@ stop-node-on-err = {{ .Streaming.ABCI.StopNodeOnErr }} # Note, this configuration only applies to SDK built-in app-side mempool # implementations. max-txs = {{ .Mempool.MaxTxs }} + +############################################################################### +### Modules ### +############################################################################### + +[staking] +# cache-size defines the maximum number of time-based queue entries to cache +# for unbonding validators, unbonding delegations, and redelegations. +# cache-size = 0 means unlimited cache (no size limit). +# cache-size < 0 means the cache is disabled. +# cache-size > 0 sets a size limit for the cache. +cache-size = {{ .Staking.CacheSize }} ` var configTemplate *template.Template diff --git a/server/start.go b/server/start.go index 57a25576ac67..4fa392609dd5 100644 --- a/server/start.go +++ b/server/start.go @@ -100,6 +100,9 @@ const ( // mempool flags FlagMempoolMaxTxs = "mempool.max-txs" + // staking flags + FlagStakingCacheSize = "staking.cache-size" + // testnet keys KeyIsTestnet = "is-testnet" KeyNewChainID = "new-chain-ID" @@ -997,6 +1000,7 @@ func addStartNodeFlags(cmd *cobra.Command, opts StartCmdOptions) { cmd.Flags().Uint32(FlagStateSyncSnapshotKeepRecent, 2, "State sync snapshot to keep") cmd.Flags().Bool(FlagDisableIAVLFastNode, false, "Disable fast node for IAVL tree") cmd.Flags().Int(FlagMempoolMaxTxs, mempool.DefaultMaxTx, "Sets MaxTx value for the app-side mempool") + cmd.Flags().Int(FlagStakingCacheSize, 0, "Sets the cache size for staking unbonding queues (0 = unlimited, negative = disabled)") cmd.Flags().Duration(FlagShutdownGrace, 0*time.Second, "On Shutdown, duration 
to wait for resource clean up") // support old flags name for backwards compatibility diff --git a/simapp/app.go b/simapp/app.go index fd2af63d0636..e535916d34fd 100644 --- a/simapp/app.go +++ b/simapp/app.go @@ -316,8 +316,10 @@ func NewSimApp( } app.txConfig = txConfig + stakingCacheSize := cast.ToInt(appOpts.Get(server.FlagStakingCacheSize)) + app.StakingKeeper = stakingkeeper.NewKeeper( - appCodec, runtime.NewKVStoreService(keys[stakingtypes.StoreKey]), app.AccountKeeper, app.BankKeeper, authtypes.NewModuleAddress(govtypes.ModuleName).String(), authcodec.NewBech32Codec(sdk.Bech32PrefixValAddr), authcodec.NewBech32Codec(sdk.Bech32PrefixConsAddr), + appCodec, runtime.NewKVStoreService(keys[stakingtypes.StoreKey]), app.AccountKeeper, app.BankKeeper, authtypes.NewModuleAddress(govtypes.ModuleName).String(), authcodec.NewBech32Codec(sdk.Bech32PrefixValAddr), authcodec.NewBech32Codec(sdk.Bech32PrefixConsAddr), stakingCacheSize, ) app.MintKeeper = mintkeeper.NewKeeper(appCodec, runtime.NewKVStoreService(keys[minttypes.StoreKey]), app.StakingKeeper, app.AccountKeeper, app.BankKeeper, authtypes.FeeCollectorName, authtypes.NewModuleAddress(govtypes.ModuleName).String()) diff --git a/simapp/app_di.go b/simapp/app_di.go index 71999c12ac04..1891aba1df1b 100644 --- a/simapp/app_di.go +++ b/simapp/app_di.go @@ -6,6 +6,7 @@ import ( "io" dbm "github.com/cosmos/cosmos-db" + "github.com/spf13/cast" clienthelpers "cosmossdk.io/client/v2/helpers" "cosmossdk.io/depinject" @@ -96,6 +97,10 @@ func init() { } } +func ProvideStakingCacheSize(appOpts servertypes.AppOptions) int { + return cast.ToInt(appOpts.Get(server.FlagStakingCacheSize)) +} + // NewSimApp returns a reference to an initialized SimApp. func NewSimApp( logger log.Logger, @@ -158,6 +163,7 @@ func NewSimApp( // custom function that implements the minttypes.InflationCalculationFn // interface. 
), + depinject.Provide(ProvideStakingCacheSize), ) ) diff --git a/tests/integration/distribution/keeper/msg_server_test.go b/tests/integration/distribution/keeper/msg_server_test.go index 5c7612dd0df2..bdd46a40b536 100644 --- a/tests/integration/distribution/keeper/msg_server_test.go +++ b/tests/integration/distribution/keeper/msg_server_test.go @@ -107,7 +107,7 @@ func initFixture(t testing.TB) *fixture { log.NewNopLogger(), ) - stakingKeeper := stakingkeeper.NewKeeper(cdc, runtime.NewKVStoreService(keys[stakingtypes.StoreKey]), accountKeeper, bankKeeper, authority.String(), addresscodec.NewBech32Codec(sdk.Bech32PrefixValAddr), addresscodec.NewBech32Codec(sdk.Bech32PrefixConsAddr)) + stakingKeeper := stakingkeeper.NewKeeper(cdc, runtime.NewKVStoreService(keys[stakingtypes.StoreKey]), accountKeeper, bankKeeper, authority.String(), addresscodec.NewBech32Codec(sdk.Bech32PrefixValAddr), addresscodec.NewBech32Codec(sdk.Bech32PrefixConsAddr), 0) distrKeeper := distrkeeper.NewKeeper( cdc, runtime.NewKVStoreService(keys[distrtypes.StoreKey]), accountKeeper, bankKeeper, stakingKeeper, distrtypes.ModuleName, authority.String(), diff --git a/tests/integration/evidence/keeper/infraction_test.go b/tests/integration/evidence/keeper/infraction_test.go index 596bca82418b..3ce059da9e9f 100644 --- a/tests/integration/evidence/keeper/infraction_test.go +++ b/tests/integration/evidence/keeper/infraction_test.go @@ -124,7 +124,7 @@ func initFixture(t testing.TB) *fixture { log.NewNopLogger(), ) - stakingKeeper := stakingkeeper.NewKeeper(cdc, runtime.NewKVStoreService(keys[stakingtypes.StoreKey]), accountKeeper, bankKeeper, authority.String(), addresscodec.NewBech32Codec(sdk.Bech32PrefixValAddr), addresscodec.NewBech32Codec(sdk.Bech32PrefixConsAddr)) + stakingKeeper := stakingkeeper.NewKeeper(cdc, runtime.NewKVStoreService(keys[stakingtypes.StoreKey]), accountKeeper, bankKeeper, authority.String(), addresscodec.NewBech32Codec(sdk.Bech32PrefixValAddr), 
addresscodec.NewBech32Codec(sdk.Bech32PrefixConsAddr), 0) slashingKeeper := slashingkeeper.NewKeeper(cdc, codec.NewLegacyAmino(), runtime.NewKVStoreService(keys[slashingtypes.StoreKey]), stakingKeeper, authority.String()) diff --git a/tests/integration/gov/keeper/keeper_test.go b/tests/integration/gov/keeper/keeper_test.go index 1d8001878588..b52e2a80d743 100644 --- a/tests/integration/gov/keeper/keeper_test.go +++ b/tests/integration/gov/keeper/keeper_test.go @@ -95,7 +95,7 @@ func initFixture(t testing.TB) *fixture { log.NewNopLogger(), ) - stakingKeeper := stakingkeeper.NewKeeper(cdc, runtime.NewKVStoreService(keys[stakingtypes.StoreKey]), accountKeeper, bankKeeper, authority.String(), addresscodec.NewBech32Codec(sdk.Bech32PrefixValAddr), addresscodec.NewBech32Codec(sdk.Bech32PrefixConsAddr)) + stakingKeeper := stakingkeeper.NewKeeper(cdc, runtime.NewKVStoreService(keys[stakingtypes.StoreKey]), accountKeeper, bankKeeper, authority.String(), addresscodec.NewBech32Codec(sdk.Bech32PrefixValAddr), addresscodec.NewBech32Codec(sdk.Bech32PrefixConsAddr), 0) // set default staking params stakingKeeper.SetParams(newCtx, stakingtypes.DefaultParams()) diff --git a/tests/integration/slashing/keeper/keeper_test.go b/tests/integration/slashing/keeper/keeper_test.go index 74aacdf36859..a12797dc0632 100644 --- a/tests/integration/slashing/keeper/keeper_test.go +++ b/tests/integration/slashing/keeper/keeper_test.go @@ -93,7 +93,7 @@ func initFixture(t testing.TB) *fixture { log.NewNopLogger(), ) - stakingKeeper := stakingkeeper.NewKeeper(cdc, runtime.NewKVStoreService(keys[stakingtypes.StoreKey]), accountKeeper, bankKeeper, authority.String(), addresscodec.NewBech32Codec(sdk.Bech32PrefixValAddr), addresscodec.NewBech32Codec(sdk.Bech32PrefixConsAddr)) + stakingKeeper := stakingkeeper.NewKeeper(cdc, runtime.NewKVStoreService(keys[stakingtypes.StoreKey]), accountKeeper, bankKeeper, authority.String(), addresscodec.NewBech32Codec(sdk.Bech32PrefixValAddr), 
addresscodec.NewBech32Codec(sdk.Bech32PrefixConsAddr), 0) slashingKeeper := slashingkeeper.NewKeeper(cdc, &codec.LegacyAmino{}, runtime.NewKVStoreService(keys[slashingtypes.StoreKey]), stakingKeeper, authority.String()) diff --git a/tests/integration/staking/keeper/common_test.go b/tests/integration/staking/keeper/common_test.go index b5e2815efb27..0ac1d16195cd 100644 --- a/tests/integration/staking/keeper/common_test.go +++ b/tests/integration/staking/keeper/common_test.go @@ -137,7 +137,7 @@ func initFixture(t testing.TB) *fixture { log.NewNopLogger(), ) - stakingKeeper := stakingkeeper.NewKeeper(cdc, runtime.NewKVStoreService(keys[types.StoreKey]), accountKeeper, bankKeeper, authority.String(), addresscodec.NewBech32Codec(sdk.Bech32PrefixValAddr), addresscodec.NewBech32Codec(sdk.Bech32PrefixConsAddr)) + stakingKeeper := stakingkeeper.NewKeeper(cdc, runtime.NewKVStoreService(keys[types.StoreKey]), accountKeeper, bankKeeper, authority.String(), addresscodec.NewBech32Codec(sdk.Bech32PrefixValAddr), addresscodec.NewBech32Codec(sdk.Bech32PrefixConsAddr), 0) authModule := auth.NewAppModule(cdc, accountKeeper, authsims.RandomGenesisAccounts, nil) bankModule := bank.NewAppModule(cdc, bankKeeper, accountKeeper, nil) diff --git a/tests/integration/staking/keeper/determinstic_test.go b/tests/integration/staking/keeper/determinstic_test.go index cc3c59c41dd8..72ce34dafeca 100644 --- a/tests/integration/staking/keeper/determinstic_test.go +++ b/tests/integration/staking/keeper/determinstic_test.go @@ -108,7 +108,7 @@ func initDeterministicFixture(t *testing.T) *deterministicFixture { log.NewNopLogger(), ) - stakingKeeper := stakingkeeper.NewKeeper(cdc, runtime.NewKVStoreService(keys[stakingtypes.StoreKey]), accountKeeper, bankKeeper, authority.String(), addresscodec.NewBech32Codec(sdk.Bech32PrefixValAddr), addresscodec.NewBech32Codec(sdk.Bech32PrefixConsAddr)) + stakingKeeper := stakingkeeper.NewKeeper(cdc, runtime.NewKVStoreService(keys[stakingtypes.StoreKey]), 
accountKeeper, bankKeeper, authority.String(), addresscodec.NewBech32Codec(sdk.Bech32PrefixValAddr), addresscodec.NewBech32Codec(sdk.Bech32PrefixConsAddr), 0) authModule := auth.NewAppModule(cdc, accountKeeper, authsims.RandomGenesisAccounts, nil) bankModule := bank.NewAppModule(cdc, bankKeeper, accountKeeper, nil) @@ -852,4 +852,4 @@ func TestGRPCParams(t *testing.T) { assert.NilError(t, err) testdata.DeterministicIterations(f.ctx, t, &stakingtypes.QueryParamsRequest{}, f.queryClient.Params, 1114, false) -} \ No newline at end of file +} diff --git a/x/staking/cache/cache.go b/x/staking/cache/cache.go new file mode 100644 index 000000000000..7ba0c5dc3e57 --- /dev/null +++ b/x/staking/cache/cache.go @@ -0,0 +1,340 @@ +package cache + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "cosmossdk.io/log" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/staking/types" +) + +type slice[T any] interface { + ~[]T +} + +type CacheEntry[K comparable, V slice[T], T any] struct { + mu sync.RWMutex + data map[K]V + // indicates if the cache requires a reload from the store. + dirty atomic.Bool + // indicates if the cache is full. + full atomic.Bool + // max defines the maximum number of entries in each cache map + // to prevent OOM attacks. + // if the size is 0, the cache is unlimited. 
+ max uint + + loadFromStore func(ctx context.Context) (map[K]V, error) +} + +func NewCacheEntry[K comparable, V slice[T], T any](max uint, loadFromStore func(ctx context.Context) (map[K]V, error)) *CacheEntry[K, V, T] { + entry := &CacheEntry[K, V, T]{max: max, loadFromStore: loadFromStore} + entry.dirty.Store(true) + return entry +} + +func (e *CacheEntry[K, V, T]) get() map[K]V { + e.mu.RLock() + defer e.mu.RUnlock() + + copied := make(map[K]V, len(e.data)) + + if e.data == nil { + return copied + } + + for k, v := range e.data { + sliceCopy := make([]T, len(v)) + copy(sliceCopy, v) + copied[k] = sliceCopy + } + + return copied +} + +func (e *CacheEntry[K, V, T]) getEntry(key K) V { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.data == nil { + return make([]T, 0) + } + + value, exists := e.data[key] + if !exists { + return make([]T, 0) + } + + sliceCopy := make([]T, len(value)) + copy(sliceCopy, value) + return sliceCopy +} + +func (e *CacheEntry[K, V, T]) setEntry(key K, value V) { + e.mu.Lock() + defer e.mu.Unlock() + + if e.full.Load() { + return + } + + if e.data == nil { + e.data = make(map[K]V) + } + + sliceCopy := make([]T, len(value)) + copy(sliceCopy, value) + e.data[key] = sliceCopy + + if e.max > 0 && uint(len(e.data)) == e.max { + e.full.Store(true) + } +} + +func (e *CacheEntry[K, V, T]) deleteEntry(key K) { + e.mu.Lock() + defer e.mu.Unlock() + + if e.data == nil { + return + } + + delete(e.data, key) + if e.max > 0 && uint(len(e.data)) < e.max { + e.full.Store(false) + } +} + +func (e *CacheEntry[K, V, T]) clear() { + e.mu.Lock() + defer e.mu.Unlock() + e.data = make(map[K]V) + e.full.Store(false) +} + +type ValidatorsQueueCache struct { + unbondingValidatorsQueue *CacheEntry[string, []string, string] + unbondingDelegationsQueue *CacheEntry[string, []types.DVPair, types.DVPair] + redelegationsQueue *CacheEntry[string, []types.DVVTriplet, types.DVVTriplet] + logger func(ctx context.Context) log.Logger +} + +func NewValidatorsQueueCache( + size 
uint, + logger func(ctx context.Context) log.Logger, + loadUnbondingValidators func(ctx context.Context) (map[string][]string, error), + loadUnbondingDelegations func(ctx context.Context) (map[string][]types.DVPair, error), + loadRedelegations func(ctx context.Context) (map[string][]types.DVVTriplet, error), +) *ValidatorsQueueCache { + return NewCache( + NewCacheEntry(size, loadUnbondingValidators), + NewCacheEntry(size, loadUnbondingDelegations), + NewCacheEntry(size, loadRedelegations), + logger, + ) +} + +func NewCache( + unbondingValidatorsQueue *CacheEntry[string, []string, string], + unbondingDelegationsQueue *CacheEntry[string, []types.DVPair, types.DVPair], + redelegationsQueue *CacheEntry[string, []types.DVVTriplet, types.DVVTriplet], + logger func(ctx context.Context) log.Logger, +) *ValidatorsQueueCache { + return &ValidatorsQueueCache{ + unbondingValidatorsQueue: unbondingValidatorsQueue, + unbondingDelegationsQueue: unbondingDelegationsQueue, + redelegationsQueue: redelegationsQueue, + logger: logger, + } +} + +func (c *ValidatorsQueueCache) loadUnbondingValidatorsQueue(ctx context.Context) error { + data, err := c.unbondingValidatorsQueue.loadFromStore(ctx) + if err != nil { + return err + } + + c.unbondingValidatorsQueue.clear() + + for key, value := range data { + c.unbondingValidatorsQueue.setEntry(key, value) + if c.unbondingValidatorsQueue.full.Load() { + return types.ErrCacheMaxSizeReached + } + } + c.unbondingValidatorsQueue.dirty.Store(false) + return nil +} + +func (c *ValidatorsQueueCache) GetUnbondingValidatorsQueue(ctx context.Context) (map[string][]string, error) { + if c.unbondingValidatorsQueue.full.Load() { + return nil, types.ErrCacheMaxSizeReached + } + + if c.unbondingValidatorsQueue.dirty.Load() { + c.logger(ctx).Info("Unbonding validators queue is dirty. 
Reinitializing cache from store.") + err := c.loadUnbondingValidatorsQueue(ctx) + if err != nil { + return nil, err + } + } + + return c.unbondingValidatorsQueue.get(), nil +} + +func (c *ValidatorsQueueCache) GetUnbondingValidatorsQueueEntry(ctx context.Context, endTime time.Time, endHeight int64) ([]string, error) { + if c.unbondingValidatorsQueue.full.Load() { + return nil, types.ErrCacheMaxSizeReached + } + + if c.unbondingValidatorsQueue.dirty.Load() { + c.logger(ctx).Info("Unbonding validators queue is dirty. Reinitializing cache from store.") + err := c.loadUnbondingValidatorsQueue(ctx) + if err != nil { + return nil, err + } + } + + return c.unbondingValidatorsQueue.getEntry(types.GetCacheValidatorQueueKey(endTime, endHeight)), nil +} + +func (c *ValidatorsQueueCache) SetUnbondingValidatorQueueEntry(ctx context.Context, key string, addrs []string) error { + if c.unbondingValidatorsQueue.full.Load() { + c.unbondingValidatorsQueue.dirty.Store(true) + return types.ErrCacheMaxSizeReached + } + c.unbondingValidatorsQueue.setEntry(key, addrs) + return nil +} + +func (c *ValidatorsQueueCache) DeleteUnbondingValidatorQueueEntry(key string) { + c.unbondingValidatorsQueue.deleteEntry(key) +} + +func (c *ValidatorsQueueCache) loadUnbondingDelegationsQueue(ctx context.Context) error { + data, err := c.unbondingDelegationsQueue.loadFromStore(ctx) + if err != nil { + return err + } + + c.unbondingDelegationsQueue.clear() + + for key, value := range data { + c.unbondingDelegationsQueue.setEntry(key, value) + if c.unbondingDelegationsQueue.full.Load() { + return types.ErrCacheMaxSizeReached + } + } + c.unbondingDelegationsQueue.dirty.Store(false) + return nil +} + +func (c *ValidatorsQueueCache) GetUnbondingDelegationsQueue(ctx context.Context) (map[string][]types.DVPair, error) { + if c.unbondingDelegationsQueue.full.Load() { + return nil, types.ErrCacheMaxSizeReached + } + + if c.unbondingDelegationsQueue.dirty.Load() { + c.logger(ctx).Info("Unbonding delegations queue 
is dirty. Reinitializing cache from store.") + err := c.loadUnbondingDelegationsQueue(ctx) + if err != nil { + return nil, err + } + } + + return c.unbondingDelegationsQueue.get(), nil +} + +func (c *ValidatorsQueueCache) GetUnbondingDelegationsQueueEntry(ctx context.Context, endTime time.Time) ([]types.DVPair, error) { + if c.unbondingDelegationsQueue.full.Load() { + return nil, types.ErrCacheMaxSizeReached + } + + if c.unbondingDelegationsQueue.dirty.Load() { + err := c.loadUnbondingDelegationsQueue(ctx) + if err != nil { + return nil, err + } + } + + return c.unbondingDelegationsQueue.getEntry(sdk.FormatTimeString(endTime)), nil +} + +func (c *ValidatorsQueueCache) SetUnbondingDelegationsQueueEntry(ctx context.Context, key string, delegations []types.DVPair) error { + if c.unbondingDelegationsQueue.full.Load() { + c.unbondingDelegationsQueue.dirty.Store(true) + return types.ErrCacheMaxSizeReached + } + c.unbondingDelegationsQueue.setEntry(key, delegations) + return nil +} + +func (c *ValidatorsQueueCache) DeleteUnbondingDelegationQueueEntry(key string) { + c.unbondingDelegationsQueue.deleteEntry(key) +} + +func (c *ValidatorsQueueCache) loadRedelegationsQueue(ctx context.Context) error { + data, err := c.redelegationsQueue.loadFromStore(ctx) + if err != nil { + return err + } + + c.redelegationsQueue.clear() + + for key, value := range data { + c.redelegationsQueue.setEntry(key, value) + if c.redelegationsQueue.full.Load() { + return types.ErrCacheMaxSizeReached + } + } + c.redelegationsQueue.dirty.Store(false) + return nil +} + +func (c *ValidatorsQueueCache) GetRedelegationsQueue(ctx context.Context) (map[string][]types.DVVTriplet, error) { + if c.redelegationsQueue.full.Load() { + return nil, types.ErrCacheMaxSizeReached + } + + if c.redelegationsQueue.dirty.Load() { + c.logger(ctx).Info("Redelegations queue is dirty. 
Reinitializing cache from store.") + err := c.loadRedelegationsQueue(ctx) + if err != nil { + return nil, err + } + } + + return c.redelegationsQueue.get(), nil +} + +func (c *ValidatorsQueueCache) GetRedelegationsQueueEntry(ctx context.Context, endTime time.Time) ([]types.DVVTriplet, error) { + if c.redelegationsQueue.full.Load() { + return nil, types.ErrCacheMaxSizeReached + } + + if c.redelegationsQueue.dirty.Load() { + c.logger(ctx).Info("Redelegations queue is dirty. Reinitializing cache from store.") + err := c.loadRedelegationsQueue(ctx) + if err != nil { + return nil, err + } + } + + return c.redelegationsQueue.getEntry(sdk.FormatTimeString(endTime)), nil +} + +func (c *ValidatorsQueueCache) SetRedelegationsQueueEntry(ctx context.Context, key string, redelegations []types.DVVTriplet) error { + if c.redelegationsQueue.full.Load() { + c.redelegationsQueue.dirty.Store(true) + return types.ErrCacheMaxSizeReached + } + c.redelegationsQueue.setEntry(key, redelegations) + return nil +} + +func (c *ValidatorsQueueCache) DeleteRedelegationsQueueEntry(key string) { + c.redelegationsQueue.deleteEntry(key) +} diff --git a/x/staking/cache/cache_test.go b/x/staking/cache/cache_test.go new file mode 100644 index 000000000000..a0427dc30011 --- /dev/null +++ b/x/staking/cache/cache_test.go @@ -0,0 +1,1019 @@ +package cache + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "cosmossdk.io/log" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/staking/types" + "github.com/stretchr/testify/require" +) + +func TestCacheEntry_BasicOperations(t *testing.T) { + loadFunc := func(ctx context.Context) (map[string][]string, error) { + return map[string][]string{ + "key1": {"val1", "val2"}, + "key2": {"val3"}, + }, nil + } + cache := NewCacheEntry(100, loadFunc) + + // Initially empty (dirty=true, data=nil) + data := cache.get() + require.Empty(t, data) + + // Set some data + cache.setEntry("key1", []string{"val1"}) + cache.setEntry("key2", 
[]string{"val2", "val3"}) + + retrieved := cache.get() + require.Len(t, retrieved, 2) + require.Equal(t, []string{"val1"}, retrieved["key1"]) + require.Equal(t, []string{"val2", "val3"}, retrieved["key2"]) + + // Verify it's a copy (modifying returned data shouldn't affect cache) + retrieved["key1"][0] = "modified" + retrievedAgain := cache.get() + require.Equal(t, "val1", retrievedAgain["key1"][0]) +} + +func TestCacheEntry_GetEntry(t *testing.T) { + cache := NewCacheEntry[string, []string](100, nil) + + cache.setEntry("key1", []string{"val1", "val2"}) + cache.setEntry("key2", []string{"val3"}) + + // Get specific entry + entry := cache.getEntry("key1") + require.Equal(t, []string{"val1", "val2"}, entry) + + // Get non-existent entry + nonExistent := cache.getEntry("key3") + require.Empty(t, nonExistent) + + // Verify it's a copy + entry[0] = "modified" + entryAgain := cache.getEntry("key1") + require.Equal(t, "val1", entryAgain[0]) +} + +func TestCacheEntry_SetEntry(t *testing.T) { + cache := NewCacheEntry[string, []string](100, nil) + + cache.setEntry("key1", []string{"val1"}) + cache.setEntry("key2", []string{"val2", "val3"}) + + data := cache.get() + require.Len(t, data, 2) + require.Equal(t, []string{"val1"}, data["key1"]) + require.Equal(t, []string{"val2", "val3"}, data["key2"]) +} + +func TestCacheEntry_DeleteEntry(t *testing.T) { + cache := NewCacheEntry[string, []string](100, nil) + + cache.setEntry("key1", []string{"val1"}) + cache.setEntry("key2", []string{"val2"}) + + data := cache.get() + require.Len(t, data, 2) + + cache.deleteEntry("key1") + data = cache.get() + require.Len(t, data, 1) + require.NotContains(t, data, "key1") + require.Contains(t, data, "key2") +} + +func TestCacheEntry_UnlimitedSize(t *testing.T) { + cache := NewCacheEntry[string, []string](0, nil) + + // With max=0, cache is unlimited (can store anything) + cache.setEntry("key1", []string{"val1"}) + cache.setEntry("key2", []string{"val2"}) + cache.setEntry("key3", []string{"val3"}) 
+ + data := cache.get() + require.Len(t, data, 3) + require.False(t, cache.full.Load(), "unlimited cache should never be full") + + // Add many more entries to verify it's truly unlimited + for i := 0; i < 1000; i++ { + cache.setEntry(fmt.Sprintf("key%d", i+10), []string{fmt.Sprintf("val%d", i)}) + } + + data = cache.get() + require.GreaterOrEqual(t, len(data), 1000) + require.False(t, cache.full.Load(), "unlimited cache should never be full") +} + +func TestCacheEntry_MaxSizeExceeded(t *testing.T) { + cache := NewCacheEntry[string, []string](3, nil) + + // Add entries up to limit + cache.setEntry("key1", []string{"val1"}) + cache.setEntry("key2", []string{"val2"}) + cache.setEntry("key3", []string{"val3"}) + + require.True(t, cache.full.Load()) + require.Len(t, cache.get(), 3) + + // Try to add one more - should set full flag + cache.setEntry("key4", []string{"val4"}) + + require.True(t, cache.full.Load(), "cache should be marked as full") + // key4 was not added because cache is now full + data := cache.get() + require.Len(t, data, 3) + require.NotContains(t, data, "key4") +} + +func TestCacheEntry_FullFlagPreventsWrites(t *testing.T) { + cache := NewCacheEntry[string, []string](2, nil) + + cache.setEntry("key1", []string{"val1"}) + cache.setEntry("key2", []string{"val2"}) + require.True(t, cache.full.Load()) + + // Try to add more - should be ignored + cache.setEntry("key4", []string{"val4"}) + data := cache.get() + require.Len(t, data, 2) + require.NotContains(t, data, "key4") + + // Try to edit existing entry - should be ignored + cache.setEntry("key2", []string{"val2", "val3"}) + data = cache.get() + require.Len(t, data, 2) +} + +func TestCacheEntry_DeleteClearsFull(t *testing.T) { + cache := NewCacheEntry[string, []string](2, nil) + + // Fill to capacity + cache.setEntry("key1", []string{"val1"}) + cache.setEntry("key2", []string{"val2"}) + cache.setEntry("key3", []string{"val3"}) // Triggers full + require.True(t, cache.full.Load()) + + // Delete one entry 
- should clear full flag + cache.deleteEntry("key1") + require.False(t, cache.full.Load(), "deleting should clear full flag") + + // Now we should be able to add again + cache.setEntry("key4", []string{"val4"}) + data := cache.get() + require.Len(t, data, 2) + require.Contains(t, data, "key4") +} + +// Test ValidatorsQueueCache +func TestValidatorsQueueCache_Initialization(t *testing.T) { + validatorsLoader := func(ctx context.Context) (map[string][]string, error) { + return map[string][]string{ + "time1": {"val1", "val2"}, + "time2": {"val3"}, + }, nil + } + delegationsLoader := func(ctx context.Context) (map[string][]types.DVPair, error) { + return map[string][]types.DVPair{ + "time1": {{DelegatorAddress: "del1", ValidatorAddress: "val1"}}, + }, nil + } + redelegationsLoader := func(ctx context.Context) (map[string][]types.DVVTriplet, error) { + return map[string][]types.DVVTriplet{ + "time1": {{DelegatorAddress: "del1", ValidatorSrcAddress: "val1", ValidatorDstAddress: "val2"}}, + }, nil + } + + logger := func(ctx context.Context) log.Logger { + return log.NewNopLogger() + } + + cache := NewValidatorsQueueCache( + 100, + logger, + validatorsLoader, + delegationsLoader, + redelegationsLoader, + ) + + require.NotNil(t, cache) + require.NotNil(t, cache.unbondingValidatorsQueue) + require.NotNil(t, cache.unbondingDelegationsQueue) + require.NotNil(t, cache.redelegationsQueue) +} + +func TestValidatorsQueueCache_LoadFromStore(t *testing.T) { + ctx := context.Background() + + validatorsLoader := func(ctx context.Context) (map[string][]string, error) { + return map[string][]string{ + "time1": {"val1", "val2"}, + "time2": {"val3"}, + }, nil + } + + delegationsLoader := func(ctx context.Context) (map[string][]types.DVPair, error) { + return map[string][]types.DVPair{ + "time1": {{DelegatorAddress: "del1", ValidatorAddress: "val1"}}, + }, nil + } + redelegationsLoader := func(ctx context.Context) (map[string][]types.DVVTriplet, error) { + return 
map[string][]types.DVVTriplet{ + "time1": {{DelegatorAddress: "del1", ValidatorSrcAddress: "val1", ValidatorDstAddress: "val2"}}, + }, nil + } + + logger := func(ctx context.Context) log.Logger { + return log.NewNopLogger() + } + + cache := NewValidatorsQueueCache( + 100, + logger, + validatorsLoader, + delegationsLoader, + redelegationsLoader, + ) + + // Initially dirty, should load from store + unbondingValidators, err := cache.GetUnbondingValidatorsQueue(ctx) + require.NoError(t, err) + require.Len(t, unbondingValidators, 2) + require.Equal(t, []string{"val1", "val2"}, unbondingValidators["time1"]) + require.Equal(t, []string{"val3"}, unbondingValidators["time2"]) + + unbondingDelegations, err := cache.GetUnbondingDelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, unbondingDelegations, 1) + require.Equal(t, []types.DVPair{{DelegatorAddress: "del1", ValidatorAddress: "val1"}}, unbondingDelegations["time1"]) + + redelgations, err := cache.GetRedelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, redelgations, 1) + require.Equal(t, []types.DVVTriplet{{DelegatorAddress: "del1", ValidatorSrcAddress: "val1", ValidatorDstAddress: "val2"}}, redelgations["time1"]) + + // Cache should no longer be dirty + require.False(t, cache.unbondingValidatorsQueue.dirty.Load()) + require.False(t, cache.unbondingDelegationsQueue.dirty.Load()) + require.False(t, cache.redelegationsQueue.dirty.Load()) +} + +func TestValidatorsQueueCache_DirtyReinitialization(t *testing.T) { + ctx := context.Background() + + // Test unbonding validators + valCallCount := 0 + validatorsLoader := func(ctx context.Context) (map[string][]string, error) { + valCallCount++ + if valCallCount == 1 { + return map[string][]string{ + "time1": {"val1"}, + }, nil + } + return map[string][]string{ + "time2": {"val2"}, + }, nil + } + + // Test unbonding delegations + delCallCount := 0 + delegationsLoader := func(ctx context.Context) (map[string][]types.DVPair, error) { + delCallCount++ + if 
delCallCount == 1 { + return map[string][]types.DVPair{ + "time1": {{DelegatorAddress: "del1", ValidatorAddress: "val1"}}, + }, nil + } + return map[string][]types.DVPair{ + "time2": {{DelegatorAddress: "del2", ValidatorAddress: "val2"}}, + }, nil + } + + // Test redelegations + redCallCount := 0 + redelegationsLoader := func(ctx context.Context) (map[string][]types.DVVTriplet, error) { + redCallCount++ + if redCallCount == 1 { + return map[string][]types.DVVTriplet{ + "time1": {{DelegatorAddress: "del1", ValidatorSrcAddress: "val1", ValidatorDstAddress: "val2"}}, + }, nil + } + return map[string][]types.DVVTriplet{ + "time2": {{DelegatorAddress: "del2", ValidatorSrcAddress: "val2", ValidatorDstAddress: "val3"}}, + }, nil + } + + logger := func(ctx context.Context) log.Logger { + return log.NewNopLogger() + } + + cache := NewValidatorsQueueCache( + 100, + logger, + validatorsLoader, + delegationsLoader, + redelegationsLoader, + ) + + // Test unbonding validators queue + valData, err := cache.GetUnbondingValidatorsQueue(ctx) + require.NoError(t, err) + require.Len(t, valData, 1) + require.Contains(t, valData, "time1") + require.Equal(t, 1, valCallCount) + + cache.unbondingValidatorsQueue.dirty.Store(true) + valData, err = cache.GetUnbondingValidatorsQueue(ctx) + require.NoError(t, err) + require.Len(t, valData, 1) + require.Contains(t, valData, "time2") + require.NotContains(t, valData, "time1", "old data should be cleared") + require.Equal(t, 2, valCallCount) + + // Test unbonding delegations queue + delData, err := cache.GetUnbondingDelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, delData, 1) + require.Contains(t, delData, "time1") + require.Equal(t, 1, delCallCount) + + cache.unbondingDelegationsQueue.dirty.Store(true) + delData, err = cache.GetUnbondingDelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, delData, 1) + require.Contains(t, delData, "time2") + require.NotContains(t, delData, "time1", "old data should be cleared") + 
require.Equal(t, 2, delCallCount) + + // Test redelegations queue + redData, err := cache.GetRedelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, redData, 1) + require.Contains(t, redData, "time1") + require.Equal(t, 1, redCallCount) + + cache.redelegationsQueue.dirty.Store(true) + redData, err = cache.GetRedelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, redData, 1) + require.Contains(t, redData, "time2") + require.NotContains(t, redData, "time1", "old data should be cleared") + require.Equal(t, 2, redCallCount) +} + +func TestValidatorsQueueCache_FullPreventsLoad(t *testing.T) { + ctx := context.Background() + + validatorsLoader := func(ctx context.Context) (map[string][]string, error) { + // Return too much data (4 keys > max 3) + return map[string][]string{ + "time1": {"val1"}, + "time2": {"val2"}, + "time3": {"val3"}, + "time4": {"val4"}, + }, nil + } + + delegationsLoader := func(ctx context.Context) (map[string][]types.DVPair, error) { + // Return too much data (4 keys > max 3) + return map[string][]types.DVPair{ + "time1": {{DelegatorAddress: "del1", ValidatorAddress: "val1"}}, + "time2": {{DelegatorAddress: "del2", ValidatorAddress: "val2"}}, + "time3": {{DelegatorAddress: "del3", ValidatorAddress: "val3"}}, + "time4": {{DelegatorAddress: "del4", ValidatorAddress: "val4"}}, + }, nil + } + + redelegationsLoader := func(ctx context.Context) (map[string][]types.DVVTriplet, error) { + // Return too much data (4 keys > max 3) + return map[string][]types.DVVTriplet{ + "time1": {{DelegatorAddress: "del1", ValidatorSrcAddress: "val1", ValidatorDstAddress: "val2"}}, + "time2": {{DelegatorAddress: "del2", ValidatorSrcAddress: "val2", ValidatorDstAddress: "val3"}}, + "time3": {{DelegatorAddress: "del3", ValidatorSrcAddress: "val3", ValidatorDstAddress: "val4"}}, + "time4": {{DelegatorAddress: "del4", ValidatorSrcAddress: "val4", ValidatorDstAddress: "val5"}}, + }, nil + } + + logger := func(ctx context.Context) log.Logger { + return 
log.NewNopLogger() + } + + cache := NewValidatorsQueueCache( + 3, + logger, + validatorsLoader, + delegationsLoader, + redelegationsLoader, + ) + + // Try to load unbonding validators - should fail due to exceeding max + _, err := cache.GetUnbondingValidatorsQueue(ctx) + require.Error(t, err) + require.Equal(t, types.ErrCacheMaxSizeReached, err) + require.True(t, cache.unbondingValidatorsQueue.full.Load()) + + // Try to load unbonding delegations - should fail due to exceeding max + _, err = cache.GetUnbondingDelegationsQueue(ctx) + require.Error(t, err) + require.Equal(t, types.ErrCacheMaxSizeReached, err) + require.True(t, cache.unbondingDelegationsQueue.full.Load()) + + // Try to load redelegations - should fail due to exceeding max + _, err = cache.GetRedelegationsQueue(ctx) + require.Error(t, err) + require.Equal(t, types.ErrCacheMaxSizeReached, err) + require.True(t, cache.redelegationsQueue.full.Load()) +} + +func TestValidatorsQueueCache_GetEntry(t *testing.T) { + ctx := context.Background() + + logger := func(ctx context.Context) log.Logger { + return log.NewNopLogger() + } + + cache := NewValidatorsQueueCache( + 100, + logger, + nil, + nil, + nil, + ) + + // Clear dirty flags to avoid loading + cache.unbondingValidatorsQueue.dirty.Store(false) + cache.unbondingDelegationsQueue.dirty.Store(false) + cache.redelegationsQueue.dirty.Store(false) + + // Test unbonding validators queue entry + endTime := time.Now().UTC() + endHeight := int64(1000) + valKey := types.GetCacheValidatorQueueKey(endTime, endHeight) + + cache.SetUnbondingValidatorQueueEntry(ctx, valKey, []string{"val1", "val2"}) + valEntry, err := cache.GetUnbondingValidatorsQueueEntry(ctx, endTime, endHeight) + require.NoError(t, err) + require.Equal(t, []string{"val1", "val2"}, valEntry) + + // Test unbonding delegations queue entry + delKey := sdk.FormatTimeString(endTime) + delPairs := []types.DVPair{ + {DelegatorAddress: "del1", ValidatorAddress: "val1"}, + {DelegatorAddress: "del2", 
ValidatorAddress: "val2"}, + } + cache.SetUnbondingDelegationsQueueEntry(ctx, delKey, delPairs) + delEntry, err := cache.GetUnbondingDelegationsQueueEntry(ctx, endTime) + require.NoError(t, err) + require.Equal(t, delPairs, delEntry) + + // Test redelegations queue entry + redKey := sdk.FormatTimeString(endTime) + redTriplets := []types.DVVTriplet{ + {DelegatorAddress: "del1", ValidatorSrcAddress: "val1", ValidatorDstAddress: "val2"}, + {DelegatorAddress: "del2", ValidatorSrcAddress: "val2", ValidatorDstAddress: "val3"}, + } + cache.SetRedelegationsQueueEntry(ctx, redKey, redTriplets) + redEntry, err := cache.GetRedelegationsQueueEntry(ctx, endTime) + require.NoError(t, err) + require.Equal(t, redTriplets, redEntry) +} + +func TestValidatorsQueueCache_SetAndDelete(t *testing.T) { + ctx := context.Background() + + logger := func(ctx context.Context) log.Logger { + return log.NewNopLogger() + } + + cache := NewValidatorsQueueCache( + 100, + logger, + nil, + nil, + nil, + ) + + // Clear dirty flags to avoid loading + cache.unbondingValidatorsQueue.dirty.Store(false) + cache.unbondingDelegationsQueue.dirty.Store(false) + cache.redelegationsQueue.dirty.Store(false) + + // Test unbonding validators queue + err := cache.SetUnbondingValidatorQueueEntry(ctx, "key1", []string{"val1"}) + require.NoError(t, err) + err = cache.SetUnbondingValidatorQueueEntry(ctx, "key2", []string{"val2"}) + require.NoError(t, err) + + valData, err := cache.GetUnbondingValidatorsQueue(ctx) + require.NoError(t, err) + require.Len(t, valData, 2) + + cache.DeleteUnbondingValidatorQueueEntry("key1") + valData, err = cache.GetUnbondingValidatorsQueue(ctx) + require.NoError(t, err) + require.Len(t, valData, 1) + require.NotContains(t, valData, "key1") + + // Test unbonding delegations queue + err = cache.SetUnbondingDelegationsQueueEntry(ctx, "time1", []types.DVPair{ + {DelegatorAddress: "del1", ValidatorAddress: "val1"}, + }) + require.NoError(t, err) + err = 
cache.SetUnbondingDelegationsQueueEntry(ctx, "time2", []types.DVPair{ + {DelegatorAddress: "del2", ValidatorAddress: "val2"}, + }) + require.NoError(t, err) + + delData, err := cache.GetUnbondingDelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, delData, 2) + + cache.DeleteUnbondingDelegationQueueEntry("time1") + delData, err = cache.GetUnbondingDelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, delData, 1) + require.NotContains(t, delData, "time1") + + // Test redelegations queue + err = cache.SetRedelegationsQueueEntry(ctx, "time1", []types.DVVTriplet{ + {DelegatorAddress: "del1", ValidatorSrcAddress: "val1", ValidatorDstAddress: "val2"}, + }) + require.NoError(t, err) + err = cache.SetRedelegationsQueueEntry(ctx, "time2", []types.DVVTriplet{ + {DelegatorAddress: "del2", ValidatorSrcAddress: "val2", ValidatorDstAddress: "val3"}, + }) + require.NoError(t, err) + + redData, err := cache.GetRedelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, redData, 2) + + cache.DeleteRedelegationsQueueEntry("time1") + redData, err = cache.GetRedelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, redData, 1) + require.NotContains(t, redData, "time1") +} + +func TestValidatorsQueueCache_FullMarkedDirty(t *testing.T) { + ctx := context.Background() + + logger := func(ctx context.Context) log.Logger { + return log.NewNopLogger() + } + + cache := NewValidatorsQueueCache( + 2, + logger, + nil, + nil, + nil, + ) + + // Clear dirty flags + cache.unbondingValidatorsQueue.dirty.Store(false) + cache.unbondingDelegationsQueue.dirty.Store(false) + cache.redelegationsQueue.dirty.Store(false) + + // Test unbonding validators queue + err := cache.SetUnbondingValidatorQueueEntry(ctx, "key1", []string{"val1"}) + require.NoError(t, err) + err = cache.SetUnbondingValidatorQueueEntry(ctx, "key2", []string{"val2"}) + require.NoError(t, err) + + // Try to add one more - should mark as full and dirty + err = cache.SetUnbondingValidatorQueueEntry(ctx, 
"key3", []string{"val3"}) + require.Error(t, err) + require.Equal(t, types.ErrCacheMaxSizeReached, err) + require.True(t, cache.unbondingValidatorsQueue.full.Load()) + require.True(t, cache.unbondingValidatorsQueue.dirty.Load()) + + // Test unbonding delegations queue + err = cache.SetUnbondingDelegationsQueueEntry(ctx, "time1", []types.DVPair{ + {DelegatorAddress: "del1", ValidatorAddress: "val1"}, + }) + require.NoError(t, err) + err = cache.SetUnbondingDelegationsQueueEntry(ctx, "time2", []types.DVPair{ + {DelegatorAddress: "del2", ValidatorAddress: "val2"}, + }) + require.NoError(t, err) + + // Try to add one more - should mark as full and dirty + err = cache.SetUnbondingDelegationsQueueEntry(ctx, "time3", []types.DVPair{ + {DelegatorAddress: "del3", ValidatorAddress: "val3"}, + }) + require.Error(t, err) + require.Equal(t, types.ErrCacheMaxSizeReached, err) + require.True(t, cache.unbondingDelegationsQueue.full.Load()) + require.True(t, cache.unbondingDelegationsQueue.dirty.Load()) + + // Test redelegations queue + err = cache.SetRedelegationsQueueEntry(ctx, "time1", []types.DVVTriplet{ + {DelegatorAddress: "del1", ValidatorSrcAddress: "val1", ValidatorDstAddress: "val2"}, + }) + require.NoError(t, err) + err = cache.SetRedelegationsQueueEntry(ctx, "time2", []types.DVVTriplet{ + {DelegatorAddress: "del2", ValidatorSrcAddress: "val2", ValidatorDstAddress: "val3"}, + }) + require.NoError(t, err) + + // Try to add one more - should mark as full and dirty + err = cache.SetRedelegationsQueueEntry(ctx, "time3", []types.DVVTriplet{ + {DelegatorAddress: "del3", ValidatorSrcAddress: "val3", ValidatorDstAddress: "val4"}, + }) + require.Error(t, err) + require.Equal(t, types.ErrCacheMaxSizeReached, err) + require.True(t, cache.redelegationsQueue.full.Load()) + require.True(t, cache.redelegationsQueue.dirty.Load()) +} + +func TestValidatorsQueueCache_UnbondingDelegations(t *testing.T) { + ctx := context.Background() + + delegationsLoader := func(ctx context.Context) 
(map[string][]types.DVPair, error) { + return map[string][]types.DVPair{ + "time1": { + {DelegatorAddress: "del1", ValidatorAddress: "val1"}, + {DelegatorAddress: "del2", ValidatorAddress: "val2"}, + }, + }, nil + } + + logger := func(ctx context.Context) log.Logger { + return log.NewNopLogger() + } + + cache := NewValidatorsQueueCache( + 100, + logger, + nil, + delegationsLoader, + nil, + ) + + // Load from store + data, err := cache.GetUnbondingDelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, data, 1) + require.Len(t, data["time1"], 2) + + // Set individual entry + err = cache.SetUnbondingDelegationsQueueEntry(ctx, "time2", []types.DVPair{ + {DelegatorAddress: "del3", ValidatorAddress: "val3"}, + }) + require.NoError(t, err) + + data, err = cache.GetUnbondingDelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, data, 2) + + // Delete entry + cache.DeleteUnbondingDelegationQueueEntry("time1") + data, err = cache.GetUnbondingDelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, data, 1) + require.NotContains(t, data, "time1") +} + +func TestValidatorsQueueCache_UnbondingDelegationsEntry(t *testing.T) { + ctx := context.Background() + + logger := func(ctx context.Context) log.Logger { + return log.NewNopLogger() + } + + cache := NewValidatorsQueueCache( + 100, + logger, + nil, + nil, + nil, + ) + + // Clear dirty flag + cache.unbondingDelegationsQueue.dirty.Store(false) + + endTime := time.Now().UTC() + key := sdk.FormatTimeString(endTime) + + // Set entry + pairs := []types.DVPair{ + {DelegatorAddress: "del1", ValidatorAddress: "val1"}, + } + err := cache.SetUnbondingDelegationsQueueEntry(ctx, key, pairs) + require.NoError(t, err) + + // Get specific entry + entry, err := cache.GetUnbondingDelegationsQueueEntry(ctx, endTime) + require.NoError(t, err) + require.Equal(t, pairs, entry) +} + +func TestValidatorsQueueCache_Redelegations(t *testing.T) { + ctx := context.Background() + + redelegationsLoader := func(ctx 
context.Context) (map[string][]types.DVVTriplet, error) { + return map[string][]types.DVVTriplet{ + "time1": { + {DelegatorAddress: "del1", ValidatorSrcAddress: "val1", ValidatorDstAddress: "val2"}, + }, + }, nil + } + + logger := func(ctx context.Context) log.Logger { + return log.NewNopLogger() + } + + cache := NewValidatorsQueueCache( + 100, + logger, + nil, + nil, + redelegationsLoader, + ) + + // Load from store + data, err := cache.GetRedelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, data, 1) + + // Set individual entry + err = cache.SetRedelegationsQueueEntry(ctx, "time2", []types.DVVTriplet{ + {DelegatorAddress: "del2", ValidatorSrcAddress: "val2", ValidatorDstAddress: "val3"}, + }) + require.NoError(t, err) + + data, err = cache.GetRedelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, data, 2) + + // Delete entry + cache.DeleteRedelegationsQueueEntry("time1") + data, err = cache.GetRedelegationsQueue(ctx) + require.NoError(t, err) + require.Len(t, data, 1) + require.NotContains(t, data, "time1") +} + +func TestValidatorsQueueCache_RedelegationsEntry(t *testing.T) { + ctx := context.Background() + + logger := func(ctx context.Context) log.Logger { + return log.NewNopLogger() + } + + cache := NewValidatorsQueueCache( + 100, + logger, + nil, + nil, + nil, + ) + + // Clear dirty flag + cache.redelegationsQueue.dirty.Store(false) + + endTime := time.Now().UTC() + key := sdk.FormatTimeString(endTime) + + // Set entry + triplets := []types.DVVTriplet{ + {DelegatorAddress: "del1", ValidatorSrcAddress: "val1", ValidatorDstAddress: "val2"}, + } + err := cache.SetRedelegationsQueueEntry(ctx, key, triplets) + require.NoError(t, err) + + // Get specific entry + entry, err := cache.GetRedelegationsQueueEntry(ctx, endTime) + require.NoError(t, err) + require.Equal(t, triplets, entry) +} + +// Concurrent operations tests +func TestCacheEntry_ConcurrentReads(t *testing.T) { + cache := NewCacheEntry[string, []string](1000, nil) + 
cache.dirty.Store(false) // Skip loading + + for i := 0; i < 100; i++ { + cache.setEntry(fmt.Sprintf("key%d", i), []string{fmt.Sprintf("val%d", i)}) + } + + var wg sync.WaitGroup + numReaders := 50 + readsPerReader := 100 + + wg.Add(numReaders) + for i := 0; i < numReaders; i++ { + go func() { + defer wg.Done() + for j := 0; j < readsPerReader; j++ { + data := cache.get() + require.NotEmpty(t, data) + } + }() + } + + wg.Wait() +} + +func TestCacheEntry_ConcurrentWrites(t *testing.T) { + cache := NewCacheEntry[string, []string](10000, nil) + + var wg sync.WaitGroup + numWriters := 50 + writesPerWriter := 100 + + wg.Add(numWriters) + for i := 0; i < numWriters; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < writesPerWriter; j++ { + key := fmt.Sprintf("key_%d_%d", id, j) + cache.setEntry(key, []string{fmt.Sprintf("val_%d", id)}) + } + }(i) + } + + wg.Wait() + + // Verify data integrity + data := cache.get() + require.NotEmpty(t, data) + // Each writer creates unique keys, so we expect exactly numWriters * writesPerWriter entries + expectedKeys := numWriters * writesPerWriter + require.Equal(t, expectedKeys, len(data), "cache should contain exactly %d keys", expectedKeys) +} + +func TestCacheEntry_ConcurrentReadWrite(t *testing.T) { + cache := NewCacheEntry[string, []string](10000, nil) + cache.dirty.Store(false) + + for i := 0; i < 100; i++ { + cache.setEntry(fmt.Sprintf("key%d", i), []string{fmt.Sprintf("val%d", i)}) + } + + var wg sync.WaitGroup + numRoutines := 50 + operationsPerRoutine := 100 + + // Readers + wg.Add(numRoutines) + for i := 0; i < numRoutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < operationsPerRoutine; j++ { + _ = cache.get() + } + }() + } + + // Writers + wg.Add(numRoutines) + for i := 0; i < numRoutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < operationsPerRoutine; j++ { + key := fmt.Sprintf("new_key_%d_%d", id, j) + cache.setEntry(key, []string{fmt.Sprintf("val_%d", id)}) + } + }(i) + } + + 
wg.Wait() + + // Should complete without race conditions + data := cache.get() + require.NotEmpty(t, data) +} + +func TestCacheEntry_ConcurrentSetAndDelete(t *testing.T) { + cache := NewCacheEntry[string, []string](10000, nil) + + var wg sync.WaitGroup + numRoutines := 30 + operationsPerRoutine := 100 + + // Writers + wg.Add(numRoutines) + for i := 0; i < numRoutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < operationsPerRoutine; j++ { + key := fmt.Sprintf("key_%d", id%10) // Reuse some keys + cache.setEntry(key, []string{fmt.Sprintf("val_%d_%d", id, j)}) + } + }(i) + } + + // Deleters + wg.Add(numRoutines) + for i := 0; i < numRoutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < operationsPerRoutine; j++ { + key := fmt.Sprintf("key_%d", id%10) + cache.deleteEntry(key) + } + }(i) + } + + wg.Wait() + + // Should complete without panic + _ = cache.get() +} + +func TestValidatorsQueueCache_ConcurrentOperations(t *testing.T) { + ctx := context.Background() + + logger := func(ctx context.Context) log.Logger { + return log.NewNopLogger() + } + + cache := NewValidatorsQueueCache( + 10000, + logger, + nil, + nil, + nil, + ) + + // Clear dirty flags + cache.unbondingValidatorsQueue.dirty.Store(false) + cache.unbondingDelegationsQueue.dirty.Store(false) + cache.redelegationsQueue.dirty.Store(false) + + var wg sync.WaitGroup + numRoutines := 30 + operationsPerRoutine := 100 + + // Concurrent unbonding validators operations + wg.Add(numRoutines) + for i := 0; i < numRoutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < operationsPerRoutine; j++ { + key := fmt.Sprintf("v_%d_%d", id, j) + cache.SetUnbondingValidatorQueueEntry(ctx, key, []string{fmt.Sprintf("addr_%d", id)}) + } + }(i) + } + + // Concurrent unbonding delegations operations + wg.Add(numRoutines) + for i := 0; i < numRoutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < operationsPerRoutine; j++ { + key := fmt.Sprintf("d_%d_%d", id, j) + 
cache.SetUnbondingDelegationsQueueEntry(ctx, key, []types.DVPair{ + {DelegatorAddress: fmt.Sprintf("del_%d", id), ValidatorAddress: "val"}, + }) + } + }(i) + } + + // Concurrent redelegations operations + wg.Add(numRoutines) + for i := 0; i < numRoutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < operationsPerRoutine; j++ { + key := fmt.Sprintf("r_%d_%d", id, j) + cache.SetRedelegationsQueueEntry(ctx, key, []types.DVVTriplet{ + {DelegatorAddress: fmt.Sprintf("del_%d", id), ValidatorSrcAddress: "val1", ValidatorDstAddress: "val2"}, + }) + } + }(i) + } + + // Concurrent readers + wg.Add(numRoutines) + for i := 0; i < numRoutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < operationsPerRoutine; j++ { + _, _ = cache.GetUnbondingValidatorsQueue(ctx) + _, _ = cache.GetUnbondingDelegationsQueue(ctx) + _, _ = cache.GetRedelegationsQueue(ctx) + } + }() + } + + wg.Wait() + + // Verify all caches have data + vData, _ := cache.GetUnbondingValidatorsQueue(ctx) + dData, _ := cache.GetUnbondingDelegationsQueue(ctx) + rData, _ := cache.GetRedelegationsQueue(ctx) + + require.NotEmpty(t, vData) + require.NotEmpty(t, dData) + require.NotEmpty(t, rData) +} diff --git a/x/staking/keeper/delegation.go b/x/staking/keeper/delegation.go index c855ea29a613..3545b34b6963 100644 --- a/x/staking/keeper/delegation.go +++ b/x/staking/keeper/delegation.go @@ -447,6 +447,14 @@ func (k Keeper) SetUnbondingDelegationEntry( // is a slice of DVPairs corresponding to unbonding delegations that expire at a // certain time. func (k Keeper) GetUBDQueueTimeSlice(ctx context.Context, timestamp time.Time) (dvPairs []types.DVPair, err error) { + if k.cache != nil { + cachedPairs, err := k.cache.GetUnbondingDelegationsQueueEntry(ctx, timestamp) + if err == nil { + return cachedPairs, nil + } + k.Logger(ctx).Error("GetUBDQueueTimeSlice from cache failed. 
", "error", err) + } + + store := k.storeService.OpenKVStore(ctx) bz, err := store.Get(types.GetUnbondingDelegationTimeKey(timestamp)) @@ -467,7 +475,18 @@ func (k Keeper) SetUBDQueueTimeSlice(ctx context.Context, timestamp time.Time, k if err != nil { return err } - return store.Set(types.GetUnbondingDelegationTimeKey(timestamp), bz) + err = store.Set(types.GetUnbondingDelegationTimeKey(timestamp), bz) + if err != nil { + return err + } + + if k.cache != nil { + err = k.cache.SetUnbondingDelegationsQueueEntry(ctx, sdk.FormatTimeString(timestamp), keys) + if err != nil { + k.Logger(ctx).Error("SetUBDQueueTimeSlice from cache failed", "error", err) + } + } + return nil } // InsertUBDQueue inserts an unbonding delegation to the appropriate timeslice @@ -498,36 +517,116 @@ func (k Keeper) UBDQueueIterator(ctx context.Context, endTime time.Time) (corest storetypes.InclusiveEndBytes(types.GetUnbondingDelegationTimeKey(endTime))) } -// DequeueAllMatureUBDQueue returns a concatenated list of all the timeslices inclusively previous to -// currTime, and deletes the timeslices from the queue. +// UBDQueueIteratorAll returns all the unbonding queue timeslices. +func (k Keeper) UBDQueueIteratorAll(ctx context.Context) (corestore.Iterator, error) { store := k.storeService.OpenKVStore(ctx) + return store.Iterator(types.UnbondingQueueKey, storetypes.PrefixEndBytes(types.UnbondingQueueKey)) +} - // gets an iterator for all timeslices from time 0 until the current Blockheader time - unbondingTimesliceIterator, err := k.UBDQueueIterator(ctx, currTime) +// DequeueAllMatureUBDQueue returns a concatenated list of all the timeslices, and deletes the matured timeslices from the queue. 
+func (k Keeper) DequeueAllMatureUBDQueue(ctx context.Context, currTime time.Time) (matureUnbonds []types.DVPair, err error) { + unbondingDelegations, err := k.GetUBDs(ctx, currTime) if err != nil { return matureUnbonds, err } - defer unbondingTimesliceIterator.Close() - for ; unbondingTimesliceIterator.Valid(); unbondingTimesliceIterator.Next() { - timeslice := types.DVPairs{} - value := unbondingTimesliceIterator.Value() - if err = k.cdc.Unmarshal(value, &timeslice); err != nil { + keys := make([]string, 0, len(unbondingDelegations)) + + for key := range unbondingDelegations { + keys = append(keys, key) + } + + types.SortTimestampsByAscendingOrder(keys) + + store := k.storeService.OpenKVStore(ctx) + + for _, key := range keys { + t, err := sdk.ParseTime(key) + if err != nil { return matureUnbonds, err } - matureUnbonds = append(matureUnbonds, timeslice.Pairs...) + if nonMature := t.After(currTime); nonMature { + return matureUnbonds, nil + } + pairs := unbondingDelegations[key] + matureUnbonds = append(matureUnbonds, pairs...) - if err = store.Delete(unbondingTimesliceIterator.Key()); err != nil { + err = store.Delete(types.GetUnbondingDelegationTimeKey(t)) + if err != nil { return matureUnbonds, err } + if k.cache != nil { + k.cache.DeleteUnbondingDelegationQueueEntry(key) + } } return matureUnbonds, nil } +// GetUBDs gets unbonding delegations from the cache or the store +func (k Keeper) GetUBDs(ctx context.Context, endTime time.Time) (map[string][]types.DVPair, error) { + if k.cache != nil { + pairs, err := k.cache.GetUnbondingDelegationsQueue(ctx) + if err == nil { + return pairs, nil + } + k.Logger(ctx).Error("GetUBDs from cache failed. 
", "error", err) + } + return k.GetUnbondingDelegationsQueueFromStore(ctx, endTime) +} + +// GetAllUnbondingDelegationsQueueFromStore gets unbonding delegations from the store +func (k Keeper) GetAllUnbondingDelegationsQueueFromStore(ctx context.Context) (map[string][]types.DVPair, error) { + iterator, err := k.UBDQueueIteratorAll(ctx) + if err != nil { + return nil, err + } + defer iterator.Close() + ubds, err := k.getUnbondingDelegationsFromIterator(iterator) + if err != nil { + return nil, err + } + + return ubds, nil +} + +// GetUnbondingDelegationsQueueFromStore gets unbonding delegations from the store for a given time. +func (k Keeper) GetUnbondingDelegationsQueueFromStore(ctx context.Context, endTime time.Time) (map[string][]types.DVPair, error) { + iterator, err := k.UBDQueueIterator(ctx, endTime) + if err != nil { + return nil, err + } + defer iterator.Close() + ubds, err := k.getUnbondingDelegationsFromIterator(iterator) + if err != nil { + return nil, err + } + + return ubds, nil +} + +// getUnbondingDelegationsFromIterator gets unbonding delegations from the iterator. +func (k Keeper) getUnbondingDelegationsFromIterator(iterator corestore.Iterator) (map[string][]types.DVPair, error) { + unbondingDelegations := make(map[string][]types.DVPair) + + for ; iterator.Valid(); iterator.Next() { + timeslice := types.DVPairs{} + value := iterator.Value() + if err := k.cdc.Unmarshal(value, &timeslice); err != nil { + return nil, err + } + + t, err := types.ParseUnbondingDelegationTimeKey(iterator.Key()) + if err != nil { + return nil, err + } + unbondingDelegations[sdk.FormatTimeString(t)] = timeslice.Pairs + } + return unbondingDelegations, nil +} + // GetRedelegations returns a given amount of all the delegator redelegations. 
func (k Keeper) GetRedelegations(ctx context.Context, delegator sdk.AccAddress, maxRetrieve uint16) (redelegations []types.Redelegation, err error) { redelegations = make([]types.Redelegation, maxRetrieve) @@ -737,6 +836,14 @@ func (k Keeper) RemoveRedelegation(ctx context.Context, red types.Redelegation) // timeslice is a slice of DVVTriplets corresponding to redelegations that // expire at a certain time. func (k Keeper) GetRedelegationQueueTimeSlice(ctx context.Context, timestamp time.Time) (dvvTriplets []types.DVVTriplet, err error) { + if k.cache != nil { + cachedTriplets, err := k.cache.GetRedelegationsQueueEntry(ctx, timestamp) + if err == nil { + return cachedTriplets, nil + } + k.Logger(ctx).Error("GetRedelegationQueueTimeSlice from cache failed. Error: %s", err) + } + store := k.storeService.OpenKVStore(ctx) bz, err := store.Get(types.GetRedelegationTimeKey(timestamp)) if err != nil { @@ -763,7 +870,18 @@ func (k Keeper) SetRedelegationQueueTimeSlice(ctx context.Context, timestamp tim if err != nil { return err } - return store.Set(types.GetRedelegationTimeKey(timestamp), bz) + err = store.Set(types.GetRedelegationTimeKey(timestamp), bz) + if err != nil { + return err + } + + if k.cache != nil { + err = k.cache.SetRedelegationsQueueEntry(ctx, sdk.FormatTimeString(timestamp), keys) + if err != nil { + k.Logger(ctx).Error("SetRedelegationQueueTimeSlice from cache failed. Error: %s", err) + } + } + return nil } // InsertRedelegationQueue insert an redelegation delegation to the appropriate @@ -794,32 +912,51 @@ func (k Keeper) RedelegationQueueIterator(ctx context.Context, endTime time.Time return store.Iterator(types.RedelegationQueueKey, storetypes.InclusiveEndBytes(types.GetRedelegationTimeKey(endTime))) } -// DequeueAllMatureRedelegationQueue returns a concatenated list of all the -// timeslices inclusively previous to currTime, and deletes the timeslices from -// the queue. 
-func (k Keeper) DequeueAllMatureRedelegationQueue(ctx context.Context, currTime time.Time) (matureRedelegations []types.DVVTriplet, err error) { +// RedelegationQueueIteratorAll returns all the redelegation queue timeslices +func (k Keeper) RedelegationQueueIteratorAll(ctx context.Context) (storetypes.Iterator, error) { store := k.storeService.OpenKVStore(ctx) + return store.Iterator(types.RedelegationQueueKey, storetypes.PrefixEndBytes(types.RedelegationQueueKey)) +} - // gets an iterator for all timeslices from time 0 until the current Blockheader time - sdkCtx := sdk.UnwrapSDKContext(ctx) - redelegationTimesliceIterator, err := k.RedelegationQueueIterator(ctx, sdkCtx.HeaderInfo().Time) +// DequeueAllMatureRedelegationQueue returns a concatenated list of all the +// timeslices, and deletes the matured timeslices from the queue. +func (k Keeper) DequeueAllMatureRedelegationQueue(ctx context.Context, currTime time.Time) (matureRedelegations []types.DVVTriplet, err error) { + redelegations, err := k.GetPendingRedelegations(ctx, currTime) if err != nil { return nil, err } - defer redelegationTimesliceIterator.Close() - for ; redelegationTimesliceIterator.Valid(); redelegationTimesliceIterator.Next() { - timeslice := types.DVVTriplets{} - value := redelegationTimesliceIterator.Value() - if err = k.cdc.Unmarshal(value, &timeslice); err != nil { + keys := make([]string, 0, len(redelegations)) + + for key := range redelegations { + keys = append(keys, key) + } + + types.SortTimestampsByAscendingOrder(keys) + + store := k.storeService.OpenKVStore(ctx) + + for _, key := range keys { + t, err := sdk.ParseTime(key) + if err != nil { return nil, err } - matureRedelegations = append(matureRedelegations, timeslice.Triplets...) + if nonMature := t.After(currTime); nonMature { + return matureRedelegations, nil + } + + triplets := redelegations[key] + matureRedelegations = append(matureRedelegations, triplets...) 
- if err = store.Delete(redelegationTimesliceIterator.Key()); err != nil { + err = store.Delete(types.GetRedelegationTimeKey(t)) + if err != nil { return nil, err } + + if k.cache != nil { + k.cache.DeleteRedelegationsQueueEntry(key) + } } return matureRedelegations, nil @@ -1346,3 +1483,66 @@ func (k Keeper) ValidateUnbondAmount( return shares, nil } + +// GetPendingRedelegations gets pending redelegations from the cache or the store +func (k Keeper) GetPendingRedelegations(ctx context.Context, currTime time.Time) (map[string][]types.DVVTriplet, error) { + if k.cache != nil { + redelegations, err := k.cache.GetRedelegationsQueue(ctx) + if err == nil { + return redelegations, nil + } + k.Logger(ctx).Error("GetPendingRedelegations from cache failed. Error: %s", err) + } + return k.GetRedelegationsQueueFromStore(ctx, currTime) +} + +// GetAllRedelegationsQueueFromStore gets redelegations from the store +func (k Keeper) GetAllRedelegationsQueueFromStore(ctx context.Context) (map[string][]types.DVVTriplet, error) { + iterator, err := k.RedelegationQueueIteratorAll(ctx) + if err != nil { + return nil, err + } + defer iterator.Close() + redelgations, err := k.getRedelegationsFromIterator(iterator) + if err != nil { + return nil, err + } + + return redelgations, nil +} + +// GetRedelegationsQueueFromStore gets redelegations from the store for a given time. +func (k Keeper) GetRedelegationsQueueFromStore(ctx context.Context, endTime time.Time) (map[string][]types.DVVTriplet, error) { + iterator, err := k.RedelegationQueueIterator(ctx, endTime) + if err != nil { + return nil, err + } + defer iterator.Close() + redelgations, err := k.getRedelegationsFromIterator(iterator) + if err != nil { + return nil, err + } + + return redelgations, nil +} + +// getRedelegationsFromIterator gets redelegations from the iterator. 
+func (k Keeper) getRedelegationsFromIterator(iterator corestore.Iterator) (map[string][]types.DVVTriplet, error) { + redelegations := make(map[string][]types.DVVTriplet) + + for ; iterator.Valid(); iterator.Next() { + timeslice := types.DVVTriplets{} + value := iterator.Value() + if err := k.cdc.Unmarshal(value, &timeslice); err != nil { + return nil, err + } + + t, err := types.ParseRedelegationTimeKey(iterator.Key()) + if err != nil { + return nil, err + } + redelegations[sdk.FormatTimeString(t)] = timeslice.Triplets + } + + return redelegations, nil +} diff --git a/x/staking/keeper/delegation_test.go b/x/staking/keeper/delegation_test.go index 97e1e5afc629..1ad4063ddf0c 100644 --- a/x/staking/keeper/delegation_test.go +++ b/x/staking/keeper/delegation_test.go @@ -3,13 +3,21 @@ package keeper_test import ( "time" + cmtproto "github.com/cometbft/cometbft/proto/tendermint/types" + cmttime "github.com/cometbft/cometbft/types/time" "github.com/golang/mock/gomock" "cosmossdk.io/math" + storetypes "cosmossdk.io/store/types" "github.com/cosmos/cosmos-sdk/codec/address" + "github.com/cosmos/cosmos-sdk/runtime" + sdktestutil "github.com/cosmos/cosmos-sdk/testutil" simtestutil "github.com/cosmos/cosmos-sdk/testutil/sims" sdk "github.com/cosmos/cosmos-sdk/types" + moduletestutil "github.com/cosmos/cosmos-sdk/types/module/testutil" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper" "github.com/cosmos/cosmos-sdk/x/staking/testutil" stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" @@ -1160,3 +1168,1370 @@ func (s *KeeperTestSuite) TestSetUnbondingDelegationEntry() { require.Equal(creationHeight, resUnbonding.Entries[0].CreationHeight) require.Equal(newCreationHeight, resUnbonding.Entries[1].CreationHeight) } + +func (s *KeeperTestSuite) TestGetUBDQueueTimeSlice() { + testCases := []struct { + name string + maxCacheSize int + description string + }{ + 
{ + name: "cache size < 0 (cache disabled)", + maxCacheSize: -1, + description: "should always read from store when cache is not initialized", + }, + { + name: "cache size = 0 (unlimited cache)", + maxCacheSize: 0, + description: "should use unlimited cache with no size restrictions", + }, + { + name: "cache size > unbonding delegation entries", + maxCacheSize: 10, + description: "should use cache when cache is large enough", + }, + { + name: "cache size == unbonding delegation entries", + maxCacheSize: 3, + description: "should use cache when cache size matches entries", + }, + { + name: "cache size < unbonding delegation entries", + maxCacheSize: 1, + description: "should fallback to store when cache size is exceeded", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + accountKeeper := testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + tc.maxCacheSize, + ) + s.Require().NoError(keeper.SetParams(ctx, stakingtypes.DefaultParams())) + + blockTime := time.Now().UTC() + 
blockHeight := int64(1000) + ctx = ctx.WithBlockHeight(blockHeight).WithBlockTime(blockTime) + + delAddrs, valAddrs := createValAddrs(3) + + // Create multiple unbonding delegations with different completion times + time1 := blockTime + ubd1 := stakingtypes.NewUnbondingDelegation( + delAddrs[0], + valAddrs[0], + blockHeight, + time1, + math.NewInt(10), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmos"), + ) + s.Require().NoError(keeper.InsertUBDQueue(ctx, ubd1, time1)) + + time2 := blockTime.Add(1 * time.Hour) + ubd2 := stakingtypes.NewUnbondingDelegation( + delAddrs[1], + valAddrs[1], + blockHeight, + time2, + math.NewInt(20), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmos"), + ) + s.Require().NoError(keeper.InsertUBDQueue(ctx, ubd2, time2)) + + time3 := blockTime.Add(2 * time.Hour) + ubd3 := stakingtypes.NewUnbondingDelegation( + delAddrs[2], + valAddrs[2], + blockHeight, + time3, + math.NewInt(30), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmos"), + ) + s.Require().NoError(keeper.InsertUBDQueue(ctx, ubd3, time3)) + + // Test GetUBDQueueTimeSlice for time1 + slice1, err := keeper.GetUBDQueueTimeSlice(ctx, time1) + s.Require().NoError(err) + s.Require().Equal(1, len(slice1), "should have 1 entry at time1") + s.Require().Equal(ubd1.DelegatorAddress, slice1[0].DelegatorAddress) + s.Require().Equal(ubd1.ValidatorAddress, slice1[0].ValidatorAddress) + + // Test GetUBDQueueTimeSlice for time2 + slice2, err := keeper.GetUBDQueueTimeSlice(ctx, time2) + s.Require().NoError(err) + s.Require().Equal(1, len(slice2), "should have 1 entry at time2") + s.Require().Equal(ubd2.DelegatorAddress, slice2[0].DelegatorAddress) + s.Require().Equal(ubd2.ValidatorAddress, slice2[0].ValidatorAddress) + + // Test GetUBDQueueTimeSlice for time3 + slice3, err := keeper.GetUBDQueueTimeSlice(ctx, time3) + s.Require().NoError(err) + s.Require().Equal(1, len(slice3), "should have 1 entry at time3") + 
s.Require().Equal(ubd3.DelegatorAddress, slice3[0].DelegatorAddress) + s.Require().Equal(ubd3.ValidatorAddress, slice3[0].ValidatorAddress) + + // Test calling again to verify cache consistency + slice1Again, err := keeper.GetUBDQueueTimeSlice(ctx, time1) + s.Require().NoError(err) + s.Require().Equal(len(slice1), len(slice1Again), "repeated call should return same number of entries") + s.Require().Equal(slice1[0].DelegatorAddress, slice1Again[0].DelegatorAddress) + + // Test for non-existent time (should return empty slice) + emptyTime := blockTime.Add(-1 * time.Hour) + emptySlice, err := keeper.GetUBDQueueTimeSlice(ctx, emptyTime) + s.Require().NoError(err) + s.Require().Equal(0, len(emptySlice), "should have 0 entries at non-existent time") + }) + } +} + +func (s *KeeperTestSuite) TestGetAllUnbondingDelegations() { + testCases := []struct { + name string + maxCacheSize int + description string + }{ + { + name: "cache size < 0 (cache disabled)", + maxCacheSize: -1, + description: "should always read from store when cache is not initialized", + }, + { + name: "cache size = 0 (unlimited cache)", + maxCacheSize: 0, + description: "should use unlimited cache with no size restrictions", + }, + { + name: "cache size > unbonding delegation entries", + maxCacheSize: 10, + description: "should use cache when cache is large enough", + }, + { + name: "cache size == unbonding delegation entries", + maxCacheSize: 2, + description: "should use cache when cache size matches entries", + }, + { + name: "cache size < unbonding delegation entries", + maxCacheSize: 1, + description: "should fallback to store when cache size is exceeded", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: 
cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + accountKeeper := testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + tc.maxCacheSize, + ) + s.Require().NoError(keeper.SetParams(ctx, stakingtypes.DefaultParams())) + + blockTime := time.Now().UTC() + blockHeight := int64(1000) + ctx = ctx.WithBlockHeight(blockHeight).WithBlockTime(blockTime) + + delAddrs, valAddrs := createValAddrs(2) + + // insert unbonding delegation + ubd := stakingtypes.NewUnbondingDelegation( + delAddrs[0], + valAddrs[0], + blockHeight, + blockTime, + math.NewInt(10), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmos"), + ) + + t := blockTime + s.Require().NoError(keeper.InsertUBDQueue(ctx, ubd, t)) + + // add another unbonding delegation + ubd1 := stakingtypes.NewUnbondingDelegation( + delAddrs[1], + valAddrs[1], + blockHeight, + blockTime, + math.NewInt(10), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmos"), + ) + t1 := blockTime.Add(-1 * time.Minute) + s.Require().NoError(keeper.InsertUBDQueue(ctx, ubd1, t1)) + + // get all unbonding delegations should return the inserted unbonding delegations + unbondingDelegations, err := keeper.GetUBDs(ctx, blockTime) + s.Require().NoError(err) + s.Require().Equal(2, len(unbondingDelegations)) + s.Require().Equal(ubd.DelegatorAddress, 
unbondingDelegations[sdk.FormatTimeString(t)][0].DelegatorAddress) + s.Require().Equal(ubd.ValidatorAddress, unbondingDelegations[sdk.FormatTimeString(t)][0].ValidatorAddress) + s.Require().Equal(ubd1.DelegatorAddress, unbondingDelegations[sdk.FormatTimeString(t1)][0].DelegatorAddress) + s.Require().Equal(ubd1.ValidatorAddress, unbondingDelegations[sdk.FormatTimeString(t1)][0].ValidatorAddress) + + // Test calling again to verify cache consistency + unbondingDelegations2, err := keeper.GetUBDs(ctx, blockTime) + s.Require().NoError(err) + s.Require().Equal(len(unbondingDelegations), len(unbondingDelegations2), "repeated call should return same number of entries") + }) + } +} + +func (s *KeeperTestSuite) TestInsertUBDQueue() { + testCases := []struct { + name string + maxCacheSize int + description string + }{ + { + name: "cache size < 0 (cache disabled)", + maxCacheSize: -1, + description: "should always write to store when cache is not initialized", + }, + { + name: "cache size = 0 (unlimited cache)", + maxCacheSize: 0, + description: "should use unlimited cache with no size restrictions", + }, + { + name: "cache size > unbonding delegation entries", + maxCacheSize: 10, + description: "should use cache when cache is large enough", + }, + { + name: "cache size == unbonding delegation entries", + maxCacheSize: 2, + description: "should use cache when cache size matches entries", + }, + { + name: "cache size < unbonding delegation entries", + maxCacheSize: 1, + description: "should fallback to store when cache size is exceeded", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := 
gomock.NewController(s.T()) + accountKeeper := testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + tc.maxCacheSize, + ) + s.Require().NoError(keeper.SetParams(ctx, stakingtypes.DefaultParams())) + + blockTime := time.Now().UTC() + blockHeight := int64(1000) + ctx = ctx.WithBlockHeight(blockHeight).WithBlockTime(blockTime) + + iterator, err := keeper.UBDQueueIterator(ctx, blockTime) + s.Require().NoError(err) + defer iterator.Close() + count := 0 + for ; iterator.Valid(); iterator.Next() { + count++ + } + // no unbonding delegations in the queue initially + s.Require().Equal(0, count) + + delAddrs, valAddrs := createValAddrs(3) + + // insert unbonding delegation + ubd := stakingtypes.NewUnbondingDelegation( + delAddrs[0], + valAddrs[0], + blockHeight, + blockTime, + math.NewInt(10), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmos"), + ) + + t := blockTime + s.Require().NoError(keeper.InsertUBDQueue(ctx, ubd, t)) + + // insert another unbonding delegation + ubd1 := stakingtypes.NewUnbondingDelegation( + delAddrs[1], + valAddrs[1], + blockHeight, + blockTime, + math.NewInt(10), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmos"), + ) + + s.Require().NoError(keeper.InsertUBDQueue(ctx, ubd1, t)) + + iterator1, err := keeper.UBDQueueIterator(ctx, blockTime) + s.Require().NoError(err) + defer iterator1.Close() + count1 := 
0 + for ; iterator1.Valid(); iterator1.Next() { + count1++ + } + + // unbonding delegation should be retrieved + // count 1 due to same unbonding time + s.Require().Equal(1, count1) + + // Verify GetUBDQueueTimeSlice returns the correct unbonding delegations after insertion + ubds, err := keeper.GetUBDQueueTimeSlice(ctx, blockTime) + s.Require().NoError(err) + s.Require().Equal(2, len(ubds), "should have 2 unbonding delegations at same time") + + // insert unbonding delegation with different unbonding time and height + ubd2 := stakingtypes.NewUnbondingDelegation( + delAddrs[2], + valAddrs[2], + blockHeight, + blockTime, + math.NewInt(10), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmos"), + ) + t1 := blockTime.Add(-1 * time.Minute) + s.Require().NoError(keeper.InsertUBDQueue(ctx, ubd2, t1)) + + iterator2, err := keeper.UBDQueueIterator(ctx, blockTime) + s.Require().NoError(err) + defer iterator2.Close() + count2 := 0 + for ; iterator2.Valid(); iterator2.Next() { + count2++ + } + + // unbonding delegation should be retrieved + s.Require().Equal(2, count2) + + // Verify the new unbonding delegation was inserted at the correct time + ubds2, err := keeper.GetUBDQueueTimeSlice(ctx, t1) + s.Require().NoError(err) + s.Require().Equal(1, len(ubds2), "should have 1 unbonding delegation at different time") + }) + } +} + +func (s *KeeperTestSuite) TestDequeueAllMatureUBDQueue() { + testCases := []struct { + name string + maxCacheSize int + numUnbondingDelegations int + }{ + { + name: "cache size < 0 (cache disabled)", + maxCacheSize: -1, + numUnbondingDelegations: 3, + }, + { + name: "cache size = 0 (unlimited cache)", + maxCacheSize: 0, + numUnbondingDelegations: 3, + }, + { + name: "cache size > unbonding delegations", + maxCacheSize: 5, + numUnbondingDelegations: 2, + }, + { + name: "cache size == unbonding delegations", + maxCacheSize: 2, + numUnbondingDelegations: 2, + }, + { + name: "cache size < unbonding delegations", + maxCacheSize: 1, + 
numUnbondingDelegations: 3, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + accountKeeper := testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + bankKeeper.EXPECT().DelegateCoinsFromAccountToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + bankKeeper.EXPECT().SendCoinsFromModuleToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + bankKeeper.EXPECT().UndelegateCoinsFromModuleToAccount(gomock.Any(), stakingtypes.NotBondedPoolName, gomock.Any(), gomock.Any()).AnyTimes() + + // Initialize keeper with specific cache size + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + tc.maxCacheSize, + ) + params := stakingtypes.DefaultParams() + params.UnbondingTime = 1 * time.Second + s.Require().NoError(keeper.SetParams(ctx, params)) + + blockTime := time.Now().UTC() + ctx = ctx.WithBlockTime(blockTime) + + // Create validator + valAddr := sdk.ValAddress(PKs[0].Address()) + validator := testutil.NewValidator(s.T(), valAddr, PKs[0]) + validator, _ = 
validator.AddTokensFromDel(keeper.TokensFromConsensusPower(ctx, 100)) + validator = stakingkeeper.TestingUpdateValidator(keeper, ctx, validator, true) + + // Create multiple unbonding delegations + delAddrs, _ := createValAddrs(tc.numUnbondingDelegations) + for i := 0; i < tc.numUnbondingDelegations; i++ { + // Delegate + bondAmt := keeper.TokensFromConsensusPower(ctx, 10) + _, err := keeper.Delegate(ctx, delAddrs[i], bondAmt, stakingtypes.Unbonded, validator, true) + s.Require().NoError(err) + + // Undelegate + _, _, err = keeper.Undelegate(ctx, delAddrs[i], valAddr, math.LegacyNewDec(5)) + s.Require().NoError(err) + } + + // Verify unbonding delegations were created + for i := 0; i < tc.numUnbondingDelegations; i++ { + _, err := keeper.GetUnbondingDelegation(ctx, delAddrs[i], valAddr) + s.Require().NoError(err) + } + + // Fast-forward time to maturity + ctx = ctx.WithBlockTime(blockTime.Add(params.UnbondingTime)) + + // Verify GetUBDs returns the expected number of unbonding delegations + allUBDs, err := keeper.GetUBDs(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().NotEmpty(allUBDs) + + // Verify GetUBDQueueTimeSlice returns the expected number of unbonding delegations + // In this case, it should return all unbonding delegations as all unbonding delegations are at the same time. 
+ ubds, err := keeper.GetUBDQueueTimeSlice(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().Equal(tc.numUnbondingDelegations, len(ubds)) + + // Dequeue and complete all mature unbonding delegations + matureUnbonds, err := keeper.DequeueAllMatureUBDQueue(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().Equal(tc.numUnbondingDelegations, len(matureUnbonds), "all unbonding delegations should be mature") + + // Complete the unbonding delegations + for _, dvPair := range matureUnbonds { + delAddr, err := accountKeeper.AddressCodec().StringToBytes(dvPair.DelegatorAddress) + s.Require().NoError(err) + valAddr, err := keeper.ValidatorAddressCodec().StringToBytes(dvPair.ValidatorAddress) + s.Require().NoError(err) + _, err = keeper.CompleteUnbonding(ctx, delAddr, valAddr) + s.Require().NoError(err) + } + + // Verify all unbonding delegations were completed (removed from store) + for i := 0; i < tc.numUnbondingDelegations; i++ { + _, err := keeper.GetUnbondingDelegation(ctx, delAddrs[i], valAddr) + s.Require().ErrorIs(err, stakingtypes.ErrNoUnbondingDelegation, "unbonding delegation should be completed and removed") + } + }) + } +} + +func (s *KeeperTestSuite) TestUnbondingDelegationQueueCacheRecovery() { + // This test verifies that when the cache is initially too small (exceeded), + // and then entries are dequeued, the cache can recover and be used again + // Cache size is based on the number of unique timestamps (keys), not individual entries + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + accountKeeper := testutil.NewMockAccountKeeper(ctrl) + 
accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + bankKeeper.EXPECT().DelegateCoinsFromAccountToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + bankKeeper.EXPECT().SendCoinsFromModuleToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + bankKeeper.EXPECT().UndelegateCoinsFromModuleToAccount(gomock.Any(), stakingtypes.NotBondedPoolName, gomock.Any(), gomock.Any()).AnyTimes() + + // Initialize keeper with small cache size (2 timestamps) + maxCacheSize := 2 + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + maxCacheSize, + ) + params := stakingtypes.DefaultParams() + params.UnbondingTime = 1 * time.Hour // Use a long enough time (1 hr) so we can create different timestamps + s.Require().NoError(keeper.SetParams(ctx, params)) + + baseTime := time.Now().UTC() + ctx = ctx.WithBlockTime(baseTime) + + // Create validator + valAddr := sdk.ValAddress(PKs[0].Address()) + validator := testutil.NewValidator(s.T(), valAddr, PKs[0]) + validator, _ = validator.AddTokensFromDel(keeper.TokensFromConsensusPower(ctx, 100)) + validator = stakingkeeper.TestingUpdateValidator(keeper, ctx, validator, true) + + // Create unbonding delegations at 5 different timestamps (exceeds cache size of 2) + numTimestamps := 5 + delAddrs, _ := createValAddrs(numTimestamps) + completionTimes := make([]time.Time, numTimestamps) + + for i := 0; i < numTimestamps; i++ { + // Set different block times to create different completion 
timestamps + currentTime := baseTime.Add(time.Duration(i) * time.Second) + ctx = ctx.WithBlockTime(currentTime) + completionTimes[i] = currentTime.Add(params.UnbondingTime) + + // Delegate + bondAmt := keeper.TokensFromConsensusPower(ctx, 10) + _, err := keeper.Delegate(ctx, delAddrs[i], bondAmt, stakingtypes.Unbonded, validator, true) + s.Require().NoError(err) + + // Undelegate (will create unbonding delegation with unique completion time) + _, _, err = keeper.Undelegate(ctx, delAddrs[i], valAddr, math.LegacyNewDec(5)) + s.Require().NoError(err) + } + + // Verify all unbonding delegations were created + for i := 0; i < numTimestamps; i++ { + _, err := keeper.GetUnbondingDelegation(ctx, delAddrs[i], valAddr) + s.Require().NoError(err) + } + + // At this point, cache should be exceeded (5 timestamps > maxCacheSize of 2) + // GetUBDs should still work, but will read from store instead of cache + ctx = ctx.WithBlockTime(completionTimes[numTimestamps-1]) + allUBDs, err := keeper.GetUBDs(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().Equal(5, len(allUBDs), "should have 5 different timestamps") + + // Fast-forward time to mature the first 3 timestamps and dequeue them + ctx = ctx.WithBlockTime(completionTimes[2]) + matureUnbonds, err := keeper.DequeueAllMatureUBDQueue(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().Equal(3, len(matureUnbonds), "should dequeue 3 unbonding delegations from 3 timestamps") + + // Complete the first 3 unbonding delegations + for i := 0; i < 3; i++ { + _, err = keeper.CompleteUnbonding(ctx, delAddrs[i], valAddr) + s.Require().NoError(err) + } + + // Verify the first 3 were removed + for i := 0; i < 3; i++ { + _, err := keeper.GetUnbondingDelegation(ctx, delAddrs[i], valAddr) + s.Require().ErrorIs(err, stakingtypes.ErrNoUnbondingDelegation, "unbonding delegation should be completed and removed") + } + + // Verify the last 2 still exist + for i := 3; i < numTimestamps; i++ { + _, err := 
keeper.GetUnbondingDelegation(ctx, delAddrs[i], valAddr) + s.Require().NoError(err, "unbonding delegation should still exist") + } + + // Now only 2 timestamps remain (completionTimes[3] and completionTimes[4]) + // This fits in the cache (2 timestamps == maxCacheSize) + // GetUBDs should now be able to use the cache + ctx = ctx.WithBlockTime(completionTimes[4]) + remainingUBDs, err := keeper.GetUBDs(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().Equal(2, len(remainingUBDs), "should have 2 timestamps in cache") + + // Dequeue the remaining 2 + finalMatureUnbonds, err := keeper.DequeueAllMatureUBDQueue(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().Equal(2, len(finalMatureUnbonds), "should have 2 mature unbonding delegations") + + // Complete them + for i := 3; i < numTimestamps; i++ { + _, err = keeper.CompleteUnbonding(ctx, delAddrs[i], valAddr) + s.Require().NoError(err) + } + + // Verify all unbonding delegations are now completed + for i := 3; i < numTimestamps; i++ { + _, err := keeper.GetUnbondingDelegation(ctx, delAddrs[i], valAddr) + s.Require().ErrorIs(err, stakingtypes.ErrNoUnbondingDelegation, "all unbonding delegations should be completed") + } +} + +func (s *KeeperTestSuite) TestGetAndParseUnbondingDelegationTimeKey() { + require := s.Require() + + blockTime := time.Now().UTC() + key := stakingtypes.GetUnbondingDelegationTimeKey(blockTime) + time, err := stakingtypes.ParseUnbondingDelegationTimeKey(key) + require.NoError(err) + require.Equal(blockTime, time) +} + +func (s *KeeperTestSuite) TestGetRedelegationQueueTimeSlice() { + testCases := []struct { + name string + maxCacheSize int + description string + }{ + { + name: "cache size < 0 (cache disabled)", + maxCacheSize: -1, + description: "should always read from store when cache is not initialized", + }, + { + name: "cache size = 0 (unlimited cache)", + maxCacheSize: 0, + description: "should use unlimited cache with no size restrictions", + }, + { + name: "cache size 
> redelegation entries", + maxCacheSize: 10, + description: "should use cache when cache is large enough", + }, + { + name: "cache size == redelegation entries", + maxCacheSize: 3, + description: "should use cache when cache size matches entries", + }, + { + name: "cache size < redelegation entries", + maxCacheSize: 1, + description: "should fallback to store when cache size is exceeded", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + accountKeeper := testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + tc.maxCacheSize, + ) + s.Require().NoError(keeper.SetParams(ctx, stakingtypes.DefaultParams())) + + blockTime := time.Now().UTC() + blockHeight := int64(1000) + ctx = ctx.WithBlockHeight(blockHeight).WithBlockTime(blockTime) + + delAddrs, valAddrs := createValAddrs(4) + + // Create multiple redelegations with different completion times + time1 := blockTime + red1 := stakingtypes.Redelegation{ + DelegatorAddress: delAddrs[0].String(), + ValidatorSrcAddress: 
valAddrs[0].String(), + ValidatorDstAddress: valAddrs[1].String(), + } + s.Require().NoError(keeper.InsertRedelegationQueue(ctx, red1, time1)) + + time2 := blockTime.Add(1 * time.Hour) + red2 := stakingtypes.Redelegation{ + DelegatorAddress: delAddrs[1].String(), + ValidatorSrcAddress: valAddrs[1].String(), + ValidatorDstAddress: valAddrs[2].String(), + } + s.Require().NoError(keeper.InsertRedelegationQueue(ctx, red2, time2)) + + time3 := blockTime.Add(2 * time.Hour) + red3 := stakingtypes.Redelegation{ + DelegatorAddress: delAddrs[2].String(), + ValidatorSrcAddress: valAddrs[2].String(), + ValidatorDstAddress: valAddrs[3].String(), + } + s.Require().NoError(keeper.InsertRedelegationQueue(ctx, red3, time3)) + + // Test GetRedelegationQueueTimeSlice for time1 + slice1, err := keeper.GetRedelegationQueueTimeSlice(ctx, time1) + s.Require().NoError(err) + s.Require().Equal(1, len(slice1), "should have 1 entry at time1") + s.Require().Equal(red1.DelegatorAddress, slice1[0].DelegatorAddress) + s.Require().Equal(red1.ValidatorSrcAddress, slice1[0].ValidatorSrcAddress) + s.Require().Equal(red1.ValidatorDstAddress, slice1[0].ValidatorDstAddress) + + // Test GetRedelegationQueueTimeSlice for time2 + slice2, err := keeper.GetRedelegationQueueTimeSlice(ctx, time2) + s.Require().NoError(err) + s.Require().Equal(1, len(slice2), "should have 1 entry at time2") + s.Require().Equal(red2.DelegatorAddress, slice2[0].DelegatorAddress) + s.Require().Equal(red2.ValidatorSrcAddress, slice2[0].ValidatorSrcAddress) + s.Require().Equal(red2.ValidatorDstAddress, slice2[0].ValidatorDstAddress) + + // Test GetRedelegationQueueTimeSlice for time3 + slice3, err := keeper.GetRedelegationQueueTimeSlice(ctx, time3) + s.Require().NoError(err) + s.Require().Equal(1, len(slice3), "should have 1 entry at time3") + s.Require().Equal(red3.DelegatorAddress, slice3[0].DelegatorAddress) + s.Require().Equal(red3.ValidatorSrcAddress, slice3[0].ValidatorSrcAddress) + s.Require().Equal(red3.ValidatorDstAddress, 
slice3[0].ValidatorDstAddress) + + // Test calling again to verify cache consistency + slice1Again, err := keeper.GetRedelegationQueueTimeSlice(ctx, time1) + s.Require().NoError(err) + s.Require().Equal(len(slice1), len(slice1Again), "repeated call should return same number of entries") + s.Require().Equal(slice1[0].DelegatorAddress, slice1Again[0].DelegatorAddress) + + // Test for non-existent time (should return empty slice) + emptyTime := blockTime.Add(-1 * time.Hour) + emptySlice, err := keeper.GetRedelegationQueueTimeSlice(ctx, emptyTime) + s.Require().NoError(err) + s.Require().Equal(0, len(emptySlice), "should have 0 entries at non-existent time") + }) + } +} + +func (s *KeeperTestSuite) TestGetPendingRedelegations() { + testCases := []struct { + name string + maxCacheSize int + description string + }{ + { + name: "cache size < 0 (cache disabled)", + maxCacheSize: -1, + description: "should always read from store when cache is not initialized", + }, + { + name: "cache size = 0 (unlimited cache)", + maxCacheSize: 0, + description: "should use unlimited cache with no size restrictions", + }, + { + name: "cache size > redelegation entries", + maxCacheSize: 10, + description: "should use cache when cache is large enough", + }, + { + name: "cache size == redelegation entries", + maxCacheSize: 2, + description: "should use cache when cache size matches entries", + }, + { + name: "cache size < redelegation entries", + maxCacheSize: 1, + description: "should fallback to store when cache size is exceeded", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + accountKeeper 
:= testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + tc.maxCacheSize, + ) + s.Require().NoError(keeper.SetParams(ctx, stakingtypes.DefaultParams())) + + blockTime := time.Now().UTC() + blockHeight := int64(1000) + ctx = ctx.WithBlockHeight(blockHeight).WithBlockTime(blockTime) + + delAddrs, valAddrs := createValAddrs(2) + + // insert redelegation + red := stakingtypes.Redelegation{ + DelegatorAddress: delAddrs[0].String(), + ValidatorSrcAddress: valAddrs[0].String(), + ValidatorDstAddress: valAddrs[1].String(), + } + + t := blockTime + s.Require().NoError(keeper.InsertRedelegationQueue(ctx, red, t)) + + // add another redelegation + red1 := stakingtypes.Redelegation{ + DelegatorAddress: delAddrs[1].String(), + ValidatorSrcAddress: valAddrs[1].String(), + ValidatorDstAddress: valAddrs[0].String(), + } + t1 := blockTime.Add(-1 * time.Minute) + s.Require().NoError(keeper.InsertRedelegationQueue(ctx, red1, t1)) + + // get all redelegations should return the inserted redelegations + redelegations, err := keeper.GetPendingRedelegations(ctx, blockTime) + s.Require().NoError(err) + s.Require().Equal(2, len(redelegations)) + s.Require().Equal(red.DelegatorAddress, redelegations[sdk.FormatTimeString(t)][0].DelegatorAddress) + s.Require().Equal(red.ValidatorSrcAddress, redelegations[sdk.FormatTimeString(t)][0].ValidatorSrcAddress) + s.Require().Equal(red.ValidatorDstAddress, 
redelegations[sdk.FormatTimeString(t)][0].ValidatorDstAddress) + s.Require().Equal(red1.DelegatorAddress, redelegations[sdk.FormatTimeString(t1)][0].DelegatorAddress) + s.Require().Equal(red1.ValidatorSrcAddress, redelegations[sdk.FormatTimeString(t1)][0].ValidatorSrcAddress) + s.Require().Equal(red1.ValidatorDstAddress, redelegations[sdk.FormatTimeString(t1)][0].ValidatorDstAddress) + + // Test calling again to verify cache consistency + redelegations2, err := keeper.GetPendingRedelegations(ctx, blockTime) + s.Require().NoError(err) + s.Require().Equal(len(redelegations), len(redelegations2), "repeated call should return same number of entries") + }) + } +} + +func (s *KeeperTestSuite) TestInsertRedelegationQueue() { + testCases := []struct { + name string + maxCacheSize int + description string + }{ + { + name: "cache size < 0 (cache disabled)", + maxCacheSize: -1, + description: "should always write to store when cache is not initialized", + }, + { + name: "cache size = 0 (unlimited cache)", + maxCacheSize: 0, + description: "should use unlimited cache with no size restrictions", + }, + { + name: "cache size > redelegation entries", + maxCacheSize: 10, + description: "should use cache when cache is large enough", + }, + { + name: "cache size == redelegation entries", + maxCacheSize: 2, + description: "should use cache when cache size matches entries", + }, + { + name: "cache size < redelegation entries", + maxCacheSize: 1, + description: "should fallback to store when cache size is exceeded", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + accountKeeper := 
testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + tc.maxCacheSize, + ) + s.Require().NoError(keeper.SetParams(ctx, stakingtypes.DefaultParams())) + + blockTime := time.Now().UTC() + blockHeight := int64(1000) + ctx = ctx.WithBlockHeight(blockHeight).WithBlockTime(blockTime) + + iterator, err := keeper.RedelegationQueueIterator(ctx, blockTime) + s.Require().NoError(err) + defer iterator.Close() + count := 0 + for ; iterator.Valid(); iterator.Next() { + count++ + } + // no redelegations in the queue initially + s.Require().Equal(0, count) + + delAddrs, valAddrs := createValAddrs(3) + + // insert redelegation + red := stakingtypes.NewRedelegation(delAddrs[0], valAddrs[0], valAddrs[1], 0, + time.Unix(0, 0), math.NewInt(5), + math.LegacyNewDec(5), address.NewBech32Codec("cosmosvaloper"), address.NewBech32Codec("cosmos")) + + t := blockTime + s.Require().NoError(keeper.InsertRedelegationQueue(ctx, red, t)) + + // insert another redelegation + red1 := stakingtypes.NewRedelegation(delAddrs[1], valAddrs[1], valAddrs[0], 0, + time.Unix(0, 0), math.NewInt(5), + math.LegacyNewDec(5), address.NewBech32Codec("cosmosvaloper"), address.NewBech32Codec("cosmos")) + + s.Require().NoError(keeper.InsertRedelegationQueue(ctx, red1, t)) + + iterator1, err := keeper.RedelegationQueueIterator(ctx, blockTime) + s.Require().NoError(err) + defer iterator1.Close() + count1 := 0 + for ; 
iterator1.Valid(); iterator1.Next() { + count1++ + } + + // redelegation should be retrieved + // count 1 due to same redelegation time + s.Require().Equal(1, count1) + + // Verify GetRedelegationQueueTimeSlice returns the correct redelegations after insertion + reds, err := keeper.GetRedelegationQueueTimeSlice(ctx, blockTime) + s.Require().NoError(err) + s.Require().Equal(2, len(reds), "should have 2 redelegations at same time") + + // insert another redelegation with different redelegation time and height + red2 := stakingtypes.NewRedelegation(delAddrs[2], valAddrs[2], valAddrs[0], 0, + time.Unix(0, 0), math.NewInt(5), + math.LegacyNewDec(5), address.NewBech32Codec("cosmosvaloper"), address.NewBech32Codec("cosmos")) + t2 := blockTime.Add(-1 * time.Minute) + s.Require().NoError(keeper.InsertRedelegationQueue(ctx, red2, t2)) + + iterator2, err := keeper.RedelegationQueueIterator(ctx, blockTime) + s.Require().NoError(err) + defer iterator2.Close() + count2 := 0 + for ; iterator2.Valid(); iterator2.Next() { + count2++ + } + + // redelegation should be retrieved + s.Require().Equal(2, count2) + + // Verify the new redelegation was inserted at the correct time + reds2, err := keeper.GetRedelegationQueueTimeSlice(ctx, t2) + s.Require().NoError(err) + s.Require().Equal(1, len(reds2), "should have 1 redelegation at different time") + }) + } +} + +func (s *KeeperTestSuite) TestDequeueAllMatureRedelegationQueue() { + testCases := []struct { + name string + maxCacheSize int + numRedelegations int + }{ + { + name: "cache size < 0 (cache disabled)", + maxCacheSize: -1, + numRedelegations: 3, + }, + { + name: "cache size = 0 (unlimited cache)", + maxCacheSize: 0, + numRedelegations: 3, + }, + { + name: "cache size > redelegations", + maxCacheSize: 5, + numRedelegations: 2, + }, + { + name: "cache size == redelegations", + maxCacheSize: 2, + numRedelegations: 2, + }, + { + name: "cache size < redelegations", + maxCacheSize: 1, + numRedelegations: 3, + }, + } + + for _, tc := 
range testCases { + s.Run(tc.name, func() { + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + accountKeeper := testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + bankKeeper.EXPECT().DelegateCoinsFromAccountToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + bankKeeper.EXPECT().SendCoinsFromModuleToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + + // Initialize keeper with specific cache size + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + tc.maxCacheSize, + ) + params := stakingtypes.DefaultParams() + params.UnbondingTime = 1 * time.Second // Short unbonding time for testing + s.Require().NoError(keeper.SetParams(ctx, params)) + + blockTime := time.Now().UTC() + ctx = ctx.WithBlockTime(blockTime) + + // Create 2 validators + valAddr1 := sdk.ValAddress(PKs[0].Address()) + validator1 := testutil.NewValidator(s.T(), valAddr1, PKs[0]) + validator1, _ = validator1.AddTokensFromDel(keeper.TokensFromConsensusPower(ctx, 100)) + validator1 = stakingkeeper.TestingUpdateValidator(keeper, ctx, validator1, true) + + valAddr2 := 
sdk.ValAddress(PKs[1].Address()) + validator2 := testutil.NewValidator(s.T(), valAddr2, PKs[1]) + validator2, _ = validator2.AddTokensFromDel(keeper.TokensFromConsensusPower(ctx, 100)) + stakingkeeper.TestingUpdateValidator(keeper, ctx, validator2, true) + + // Create multiple redelegations + delAddrs, _ := createValAddrs(tc.numRedelegations) + for i := 0; i < tc.numRedelegations; i++ { + // Delegate to validator1 + bondAmt := keeper.TokensFromConsensusPower(ctx, 10) + _, err := keeper.Delegate(ctx, delAddrs[i], bondAmt, stakingtypes.Unbonded, validator1, true) + s.Require().NoError(err) + + // Redelegate from validator1 to validator2 + _, err = keeper.BeginRedelegation(ctx, delAddrs[i], valAddr1, valAddr2, math.LegacyNewDec(5)) + s.Require().NoError(err) + } + + // Verify redelegations were created + for i := 0; i < tc.numRedelegations; i++ { + _, err := keeper.GetRedelegation(ctx, delAddrs[i], valAddr1, valAddr2) + s.Require().NoError(err) + } + + // Fast-forward time to maturity + ctx = ctx.WithBlockTime(blockTime.Add(params.UnbondingTime)) + + // Verify GetPendingRedelegations returns the expected number of redelegations + allReds, err := keeper.GetPendingRedelegations(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().NotEmpty(allReds, "should have pending redelegations") + + // Verify GetRedelegationQueueTimeSlice returns the expected number of redelegations + // In this case, it should return all redelegations as all redelegations mature at the same time. 
+ reds, err := keeper.GetRedelegationQueueTimeSlice(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().Equal(tc.numRedelegations, len(reds)) + + // Dequeue and complete all mature redelegations + matureRedelegations, err := keeper.DequeueAllMatureRedelegationQueue(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().Equal(tc.numRedelegations, len(matureRedelegations), "all redelegations should be mature") + + // Complete the redelegations + for _, dvvTriplet := range matureRedelegations { + delAddr, err := accountKeeper.AddressCodec().StringToBytes(dvvTriplet.DelegatorAddress) + s.Require().NoError(err) + valSrcAddr, err := keeper.ValidatorAddressCodec().StringToBytes(dvvTriplet.ValidatorSrcAddress) + s.Require().NoError(err) + valDstAddr, err := keeper.ValidatorAddressCodec().StringToBytes(dvvTriplet.ValidatorDstAddress) + s.Require().NoError(err) + _, err = keeper.CompleteRedelegation(ctx, delAddr, valSrcAddr, valDstAddr) + s.Require().NoError(err) + } + + // Verify all redelegations were completed (removed from store) + for i := 0; i < tc.numRedelegations; i++ { + _, err := keeper.GetRedelegation(ctx, delAddrs[i], valAddr1, valAddr2) + s.Require().ErrorIs(err, stakingtypes.ErrNoRedelegation, "redelegation should be completed and removed") + } + }) + } +} + +func (s *KeeperTestSuite) TestRedelegationQueueCacheRecovery() { + // This test verifies that when the cache is initially too small (exceeded), + // and then entries are dequeued, the cache can recover and be used again + // Cache size is based on the number of unique timestamps (keys), not individual entries + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + 
accountKeeper := testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + bankKeeper.EXPECT().DelegateCoinsFromAccountToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + bankKeeper.EXPECT().SendCoinsFromModuleToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + + // Initialize keeper with small cache size (2 timestamps) + maxCacheSize := 2 + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + maxCacheSize, + ) + params := stakingtypes.DefaultParams() + params.UnbondingTime = 1 * time.Hour // Use 1 hour so we can create different timestamps + s.Require().NoError(keeper.SetParams(ctx, params)) + + baseTime := time.Now().UTC() + ctx = ctx.WithBlockTime(baseTime) + + // Create 2 validators + valAddr1 := sdk.ValAddress(PKs[0].Address()) + validator1 := testutil.NewValidator(s.T(), valAddr1, PKs[0]) + validator1, _ = validator1.AddTokensFromDel(keeper.TokensFromConsensusPower(ctx, 100)) + validator1 = stakingkeeper.TestingUpdateValidator(keeper, ctx, validator1, true) + + valAddr2 := sdk.ValAddress(PKs[1].Address()) + validator2 := testutil.NewValidator(s.T(), valAddr2, PKs[1]) + validator2, _ = validator2.AddTokensFromDel(keeper.TokensFromConsensusPower(ctx, 100)) + stakingkeeper.TestingUpdateValidator(keeper, ctx, validator2, true) + + // Create redelegations at 5 different timestamps (exceeds cache size of 2) + numTimestamps := 5 + delAddrs, _ := 
createValAddrs(numTimestamps) + completionTimes := make([]time.Time, numTimestamps) + + for i := 0; i < numTimestamps; i++ { + // Set different block times to create different completion timestamps + currentTime := baseTime.Add(time.Duration(i) * time.Second) + ctx = ctx.WithBlockTime(currentTime) + completionTimes[i] = currentTime.Add(params.UnbondingTime) + + // Delegate to validator1 + bondAmt := keeper.TokensFromConsensusPower(ctx, 10) + _, err := keeper.Delegate(ctx, delAddrs[i], bondAmt, stakingtypes.Unbonded, validator1, true) + s.Require().NoError(err) + + // Redelegate from validator1 to validator2 (will create redelegation with unique completion time) + _, err = keeper.BeginRedelegation(ctx, delAddrs[i], valAddr1, valAddr2, math.LegacyNewDec(5)) + s.Require().NoError(err) + } + + // Verify all redelegations were created + for i := 0; i < numTimestamps; i++ { + _, err := keeper.GetRedelegation(ctx, delAddrs[i], valAddr1, valAddr2) + s.Require().NoError(err) + } + + // At this point, cache should be exceeded (5 timestamps > maxCacheSize of 2) + // GetPendingRedelegations should still work, but will read from store + ctx = ctx.WithBlockTime(completionTimes[numTimestamps-1]) + allReds, err := keeper.GetPendingRedelegations(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().Equal(5, len(allReds), "should have 5 different timestamps") + + // Fast-forward time to mature the first 3 timestamps and dequeue them + ctx = ctx.WithBlockTime(completionTimes[2]) + matureRedelegations, err := keeper.DequeueAllMatureRedelegationQueue(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().Equal(3, len(matureRedelegations), "should dequeue 3 redelegations from 3 timestamps") + + // Complete the first 3 redelegations + for i := 0; i < 3; i++ { + _, err = keeper.CompleteRedelegation(ctx, delAddrs[i], valAddr1, valAddr2) + s.Require().NoError(err) + } + + // Verify the first 3 were removed + for i := 0; i < 3; i++ { + _, err := keeper.GetRedelegation(ctx, 
delAddrs[i], valAddr1, valAddr2) + s.Require().ErrorIs(err, stakingtypes.ErrNoRedelegation, "redelegation should be completed and removed") + } + + // Verify the last 2 still exist + for i := 3; i < numTimestamps; i++ { + _, err := keeper.GetRedelegation(ctx, delAddrs[i], valAddr1, valAddr2) + s.Require().NoError(err, "redelegation should still exist") + } + + // Now only 2 timestamps remain (completionTimes[3] and completionTimes[4]) + // This fits in the cache (2 timestamps == maxCacheSize) + // GetPendingRedelegations should now be able to use the cache + ctx = ctx.WithBlockTime(completionTimes[4]) + remainingReds, err := keeper.GetPendingRedelegations(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().Equal(2, len(remainingReds), "should have 2 timestamps in cache") + + // Dequeue the remaining 2 + finalMatureRedelegations, err := keeper.DequeueAllMatureRedelegationQueue(ctx, ctx.BlockTime()) + s.Require().NoError(err) + s.Require().Equal(2, len(finalMatureRedelegations), "should have 2 mature redelegations") + + // Complete them + for i := 3; i < numTimestamps; i++ { + _, err = keeper.CompleteRedelegation(ctx, delAddrs[i], valAddr1, valAddr2) + s.Require().NoError(err) + } + + // Verify all redelegations are now completed + for i := 3; i < numTimestamps; i++ { + _, err := keeper.GetRedelegation(ctx, delAddrs[i], valAddr1, valAddr2) + s.Require().ErrorIs(err, stakingtypes.ErrNoRedelegation, "all redelegations should be completed") + } +} + +func (s *KeeperTestSuite) TestGetAndParseRedelegationTimeKey() { + require := s.Require() + + blockTime := time.Now().UTC() + key := stakingtypes.GetRedelegationTimeKey(blockTime) + time, err := stakingtypes.ParseRedelegationTimeKey(key) + require.NoError(err) + require.Equal(blockTime, time) +} + +func (s *KeeperTestSuite) TestSortRedelegationQueueKeysByAscendingOrder() { + require := s.Require() + + currentTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + oneHourLater := currentTime.Add(1 * time.Hour) + 
oneHourBefore := currentTime.Add(-1 * time.Hour) + + keys := []string{ + sdk.FormatTimeString(oneHourLater), + sdk.FormatTimeString(oneHourBefore), + sdk.FormatTimeString(currentTime), + } + + stakingtypes.SortTimestampsByAscendingOrder(keys) + + // Verify sorting is correct - should be sorted by timestamp ascending order + for i := 0; i < len(keys)-1; i++ { + t1, err := sdk.ParseTime(keys[i]) + require.NoError(err) + t2, err := sdk.ParseTime(keys[i+1]) + require.NoError(err) + + // Current entry should be before or equal to next entry + require.True(t1.Before(t2) || t1.Equal(t2), "timestamps should be in ascending order") + + } + + firstTime, err := sdk.ParseTime(keys[0]) + require.NoError(err) + require.Equal(oneHourBefore, firstTime) + + lastTime, err := sdk.ParseTime(keys[len(keys)-1]) + require.NoError(err) + require.Equal(oneHourLater, lastTime) +} diff --git a/x/staking/keeper/keeper.go b/x/staking/keeper/keeper.go index 45948f72224c..ec3427f7d2a6 100644 --- a/x/staking/keeper/keeper.go +++ b/x/staking/keeper/keeper.go @@ -13,6 +13,7 @@ import ( "github.com/cosmos/cosmos-sdk/codec" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/staking/cache" "github.com/cosmos/cosmos-sdk/x/staking/types" ) @@ -32,6 +33,8 @@ type Keeper struct { authority string validatorAddressCodec addresscodec.Codec consensusAddressCodec addresscodec.Codec + + cache *cache.ValidatorsQueueCache } // NewKeeper creates a new staking Keeper instance @@ -43,6 +46,7 @@ func NewKeeper( authority string, validatorAddressCodec addresscodec.Codec, consensusAddressCodec addresscodec.Codec, + maxCacheSize int, ) *Keeper { // ensure bonded and not bonded module accounts are set if addr := ak.GetModuleAddress(types.BondedPoolName); addr == nil { @@ -62,7 +66,7 @@ func NewKeeper( panic("validator and/or consensus address codec are nil") } - return &Keeper{ + k := &Keeper{ storeService: storeService, cdc: cdc, authKeeper: ak, @@ -72,6 +76,18 @@ func NewKeeper( 
validatorAddressCodec: validatorAddressCodec, consensusAddressCodec: consensusAddressCodec, } + + if maxCacheSize >= 0 { + k.cache = cache.NewValidatorsQueueCache( + uint(maxCacheSize), + k.Logger, + k.GetAllUnbondingValidatorsFromStore, + k.GetAllUnbondingDelegationsQueueFromStore, + k.GetAllRedelegationsQueueFromStore, + ) + } + + return k } // Logger returns a module-specific logger. diff --git a/x/staking/keeper/keeper_test.go b/x/staking/keeper/keeper_test.go index ce2000f733b8..a719ec62951b 100644 --- a/x/staking/keeper/keeper_test.go +++ b/x/staking/keeper/keeper_test.go @@ -66,6 +66,7 @@ func (s *KeeperTestSuite) SetupTest() { authtypes.NewModuleAddress(govtypes.ModuleName).String(), address.NewBech32Codec("cosmosvaloper"), address.NewBech32Codec("cosmosvalcons"), + 1000, ) require.NoError(keeper.SetParams(ctx, stakingtypes.DefaultParams())) diff --git a/x/staking/keeper/val_state_change.go b/x/staking/keeper/val_state_change.go index 97fceb707644..e5a1d5d4759d 100644 --- a/x/staking/keeper/val_state_change.go +++ b/x/staking/keeper/val_state_change.go @@ -40,8 +40,10 @@ func (k Keeper) BlockValidatorUpdates(ctx context.Context) ([]abci.ValidatorUpda } sdkCtx := sdk.UnwrapSDKContext(ctx) + blockTime := sdkCtx.BlockTime() + // Remove all mature unbonding delegations from the ubd queue. - matureUnbonds, err := k.DequeueAllMatureUBDQueue(ctx, sdkCtx.BlockHeader().Time) + matureUnbonds, err := k.DequeueAllMatureUBDQueue(ctx, blockTime) if err != nil { return nil, err } @@ -70,9 +72,8 @@ func (k Keeper) BlockValidatorUpdates(ctx context.Context) ([]abci.ValidatorUpda ), ) } - // Remove all mature redelegations from the red queue. 
- matureRedelegations, err := k.DequeueAllMatureRedelegationQueue(ctx, sdkCtx.BlockHeader().Time) + matureRedelegations, err := k.DequeueAllMatureRedelegationQueue(ctx, blockTime) if err != nil { return nil, err } @@ -173,13 +174,13 @@ func (k Keeper) ApplyAndReturnValidatorSetUpdates(ctx context.Context) (updates case validator.IsUnbonded(): validator, err = k.unbondedToBonded(ctx, validator) if err != nil { - return + return updates, err } amtFromNotBondedToBonded = amtFromNotBondedToBonded.Add(validator.GetTokens()) case validator.IsUnbonding(): validator, err = k.unbondingToBonded(ctx, validator) if err != nil { - return + return updates, err } amtFromNotBondedToBonded = amtFromNotBondedToBonded.Add(validator.GetTokens()) case validator.IsBonded(): diff --git a/x/staking/keeper/validator.go b/x/staking/keeper/validator.go index de5827b2693c..538fa1d98f3e 100644 --- a/x/staking/keeper/validator.go +++ b/x/staking/keeper/validator.go @@ -447,6 +447,14 @@ func (k Keeper) GetLastValidators(ctx context.Context) (validators []types.Valid // GetUnbondingValidators returns a slice of mature validator addresses that // complete their unbonding at a given time and height. func (k Keeper) GetUnbondingValidators(ctx context.Context, endTime time.Time, endHeight int64) ([]string, error) { + if k.cache != nil { + cachedAddrs, err := k.cache.GetUnbondingValidatorsQueueEntry(ctx, endTime, endHeight) + if err == nil { + return cachedAddrs, nil + } + k.Logger(ctx).Error("GetUnbondingValidators from cache failed. 
Error: %s", err) + } + store := k.storeService.OpenKVStore(ctx) bz, err := store.Get(types.GetValidatorQueueKey(endTime, endHeight)) @@ -474,7 +482,18 @@ func (k Keeper) SetUnbondingValidatorsQueue(ctx context.Context, endTime time.Ti if err != nil { return err } - return store.Set(types.GetValidatorQueueKey(endTime, endHeight), bz) + err = store.Set(types.GetValidatorQueueKey(endTime, endHeight), bz) + if err != nil { + return err + } + + if k.cache != nil { + err = k.cache.SetUnbondingValidatorQueueEntry(ctx, types.GetCacheValidatorQueueKey(endTime, endHeight), addrs) + if err != nil { + k.Logger(ctx).Error("SetUnbondingValidatorsQueue from cache failed. Error: %s", err) + } + } + return nil } // InsertUnbondingValidatorQueue inserts a given unbonding validator address into @@ -492,7 +511,14 @@ func (k Keeper) InsertUnbondingValidatorQueue(ctx context.Context, val types.Val // given height and time. func (k Keeper) DeleteValidatorQueueTimeSlice(ctx context.Context, endTime time.Time, endHeight int64) error { store := k.storeService.OpenKVStore(ctx) - return store.Delete(types.GetValidatorQueueKey(endTime, endHeight)) + err := store.Delete(types.GetValidatorQueueKey(endTime, endHeight)) + if err != nil { + return err + } + if k.cache != nil { + k.cache.DeleteUnbondingValidatorQueueEntry(types.GetCacheValidatorQueueKey(endTime, endHeight)) + } + return nil } // DeleteValidatorQueue removes a validator by address from the unbonding queue @@ -529,6 +555,12 @@ func (k Keeper) DeleteValidatorQueue(ctx context.Context, val types.Validator) e return k.SetUnbondingValidatorsQueue(ctx, val.UnbondingTime, val.UnbondingHeight, newAddrs) } +// ValidatorQueueIteratorAll gets all the validators that are unbonding +func (k Keeper) ValidatorQueueIteratorAll(ctx context.Context) (corestore.Iterator, error) { + store := k.storeService.OpenKVStore(ctx) + return store.Iterator(types.ValidatorQueueKey, storetypes.PrefixEndBytes(types.ValidatorQueueKey)) +} + // ValidatorQueueIterator 
returns an interator ranging over validators that are // unbonding whose unbonding completion occurs at the given height and time. func (k Keeper) ValidatorQueueIterator(ctx context.Context, endTime time.Time, endHeight int64) (corestore.Iterator, error) { @@ -543,34 +575,33 @@ func (k Keeper) UnbondAllMatureValidators(ctx context.Context) error { blockTime := sdkCtx.BlockTime() blockHeight := sdkCtx.BlockHeight() - // unbondingValIterator will contains all validator addresses indexed under - // the ValidatorQueueKey prefix. Note, the entire index key is composed as - // ValidatorQueueKey | timeBzLen (8-byte big endian) | timeBz | heightBz (8-byte big endian), - // so it may be possible that certain validator addresses that are iterated - // over are not ready to unbond, so an explicit check is required. - unbondingValIterator, err := k.ValidatorQueueIterator(ctx, blockTime, blockHeight) + unbondingValidators, err := k.GetPendingUnbondingValidators(ctx, blockTime, blockHeight) if err != nil { return err } - defer unbondingValIterator.Close() - for ; unbondingValIterator.Valid(); unbondingValIterator.Next() { - key := unbondingValIterator.Key() - keyTime, keyHeight, err := types.ParseValidatorQueueKey(key) + keys := make([]string, 0, len(unbondingValidators)) + for k := range unbondingValidators { + keys = append(keys, k) + } + + types.SortValidatorQueueKeysByAscendingTimestampOrder(keys) + + for _, key := range keys { + time, height, err := types.ParseCacheValidatorQueueKey(key) if err != nil { return fmt.Errorf("failed to parse unbonding key: %w", err) } + if nonMature := time.After(blockTime); nonMature { + return nil + } + // All addresses for the given key have the same unbonding height and time. // We only unbond if the height and time are less than the current height // and time. 
-		if keyHeight <= blockHeight && (keyTime.Before(blockTime) || keyTime.Equal(blockTime)) {
-			addrs := types.ValAddresses{}
-			if err = k.cdc.Unmarshal(unbondingValIterator.Value(), &addrs); err != nil {
-				return err
-			}
-
-			for _, valAddr := range addrs.Addresses {
+		if height <= blockHeight && (time.Before(blockTime) || time.Equal(blockTime)) {
+			for _, valAddr := range unbondingValidators[key] {
 				addr, err := k.validatorAddressCodec.StringToBytes(valAddr)
 				if err != nil {
 					return err
 				}
@@ -633,3 +664,67 @@ func (k Keeper) GetPubKeyByConsAddr(ctx context.Context, addr sdk.ConsAddress) (
 
 	return pubkey, nil
 }
+
+// GetPendingUnbondingValidators gets unbonding validators from the cache or the store
+func (k Keeper) GetPendingUnbondingValidators(ctx context.Context, endTime time.Time, endHeight int64) (map[string][]string, error) {
+	if k.cache != nil {
+		addrs, err := k.cache.GetUnbondingValidatorsQueue(ctx)
+		if err == nil {
+			return addrs, nil
+		}
+		k.Logger(ctx).Error("GetPendingUnbondingValidators from cache failed", "error", err)
+	}
+	return k.GetUnbondingValidatorsFromStore(ctx, endTime, endHeight)
+}
+
+// GetUnbondingValidatorsFromStore gets unbonding validators from the store for a given height and time.
+func (k Keeper) GetUnbondingValidatorsFromStore(ctx context.Context, endTime time.Time, endHeight int64) (map[string][]string, error) { + iterator, err := k.ValidatorQueueIterator(ctx, endTime, endHeight) + if err != nil { + return nil, err + } + defer iterator.Close() + unbondingValidators, err := k.getUnbondingValidatorsFromIterator(iterator) + if err != nil { + return nil, err + } + + return unbondingValidators, nil +} + +// GetAllUnbondingValidatorsFromStore gets unbonding validators from the store +func (k Keeper) GetAllUnbondingValidatorsFromStore(ctx context.Context) (map[string][]string, error) { + iterator, err := k.ValidatorQueueIteratorAll(ctx) + if err != nil { + return nil, err + } + defer iterator.Close() + unbondingValidators, err := k.getUnbondingValidatorsFromIterator(iterator) + if err != nil { + return nil, err + } + + return unbondingValidators, nil +} + +// getUnbondingValidatorsFromIterator gets unbonding validators from the iterator. +func (k Keeper) getUnbondingValidatorsFromIterator(iterator corestore.Iterator) (map[string][]string, error) { + unbondingValidators := make(map[string][]string) + + for ; iterator.Valid(); iterator.Next() { + key := iterator.Key() + keyTime, keyHeight, err := types.ParseValidatorQueueKey(key) + if err != nil { + return nil, fmt.Errorf("failed to parse unbonding key: %w", err) + } + + addrs := types.ValAddresses{} + if err = k.cdc.Unmarshal(iterator.Value(), &addrs); err != nil { + return nil, err + } + + unbondingValidators[types.GetCacheValidatorQueueKey(keyTime, keyHeight)] = addrs.Addresses + } + + return unbondingValidators, nil +} diff --git a/x/staking/keeper/validator_test.go b/x/staking/keeper/validator_test.go index f5015f7844a2..4b37a60f0dcc 100644 --- a/x/staking/keeper/validator_test.go +++ b/x/staking/keeper/validator_test.go @@ -4,11 +4,20 @@ import ( "time" abci "github.com/cometbft/cometbft/abci/types" + cmtproto "github.com/cometbft/cometbft/proto/tendermint/types" + cmttime 
"github.com/cometbft/cometbft/types/time" "github.com/golang/mock/gomock" "cosmossdk.io/math" + storetypes "cosmossdk.io/store/types" + "github.com/cosmos/cosmos-sdk/codec/address" + "github.com/cosmos/cosmos-sdk/runtime" + sdktestutil "github.com/cosmos/cosmos-sdk/testutil" sdk "github.com/cosmos/cosmos-sdk/types" + moduletestutil "github.com/cosmos/cosmos-sdk/types/module/testutil" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper" "github.com/cosmos/cosmos-sdk/x/staking/testutil" stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types" @@ -440,3 +449,666 @@ func (s *KeeperTestSuite) TestUnbondingValidator() { require.NoError(err) require.Equal(stakingtypes.Unbonded, validator.Status) } + +func (s *KeeperTestSuite) TestGetAllPendingUnbondingValidators() { + testCases := []struct { + name string + maxCacheSize int + description string + }{ + { + name: "cache size < 0 (cache disabled)", + maxCacheSize: -1, + description: "should always read from store when cache is not initialized", + }, + { + name: "cache size = 0 (unlimited cache)", + maxCacheSize: 0, + description: "should use unlimited cache with no size restrictions", + }, + { + name: "cache size > unbonding queue entries", + maxCacheSize: 10, + description: "should use cache when cache is large enough", + }, + { + name: "cache size == unbonding queue entries", + maxCacheSize: 2, + description: "should use cache when cache size matches entries", + }, + { + name: "cache size < unbonding queue entries", + maxCacheSize: 1, + description: "should fallback to store when cache size is exceeded", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := 
testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + accountKeeper := testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + tc.maxCacheSize, + ) + s.Require().NoError(keeper.SetParams(ctx, stakingtypes.DefaultParams())) + + blockTime := time.Now().UTC() + blockHeight := int64(1000) + ctx = ctx.WithBlockHeight(blockHeight).WithBlockTime(blockTime) + + // add ready to unbond validator + valPubKey := PKs[0] + valAddr := sdk.ValAddress(valPubKey.Address().Bytes()) + val := testutil.NewValidator(s.T(), valAddr, valPubKey) + val.UnbondingHeight = blockHeight + val.UnbondingTime = blockTime + val.Status = stakingtypes.Unbonding + s.Require().NoError(keeper.SetValidator(ctx, val)) + s.Require().NoError(keeper.InsertUnbondingValidatorQueue(ctx, val)) + + // add another unbonding validator + valAddr1 := sdk.ValAddress(PKs[1].Address().Bytes()) + validator1 := testutil.NewValidator(s.T(), valAddr1, PKs[1]) + valUnbondingHeight1 := blockHeight - 10 + valUnbondingTime1 := blockTime.Add(-1 * time.Minute) + validator1.UnbondingHeight = valUnbondingHeight1 + validator1.UnbondingTime = valUnbondingTime1 + validator1.Status = stakingtypes.Unbonding + s.Require().NoError(keeper.SetValidator(ctx, validator1)) + s.Require().NoError(keeper.InsertUnbondingValidatorQueue(ctx, validator1)) 
+ + // get pending unbonding validators should return the inserted validators + unbondingValidators, err := keeper.GetPendingUnbondingValidators(ctx, val.UnbondingTime, val.UnbondingHeight) + s.Require().NoError(err) + s.Require().Equal(2, len(unbondingValidators)) + s.Require().Equal(val.GetOperator(), unbondingValidators[stakingtypes.GetCacheValidatorQueueKey(val.UnbondingTime, val.UnbondingHeight)][0]) + s.Require().Equal(validator1.GetOperator(), unbondingValidators[stakingtypes.GetCacheValidatorQueueKey(validator1.UnbondingTime, validator1.UnbondingHeight)][0]) + + // Test calling again to verify cache consistency + unbondingValidators2, err := keeper.GetPendingUnbondingValidators(ctx, val.UnbondingTime, val.UnbondingHeight) + s.Require().NoError(err) + s.Require().Equal(len(unbondingValidators), len(unbondingValidators2), "repeated call should return same number of entries") + }) + } +} + +func (s *KeeperTestSuite) TestInsertUnbondingValidatorQueue() { + testCases := []struct { + name string + maxCacheSize int + description string + }{ + { + name: "cache size < 0 (cache disabled)", + maxCacheSize: -1, + description: "should always write to store when cache is not initialized", + }, + { + name: "cache size = 0 (unlimited cache)", + maxCacheSize: 0, + description: "should use unlimited cache with no size restrictions", + }, + { + name: "cache size > unbonding queue entries", + maxCacheSize: 10, + description: "should use cache when cache is large enough", + }, + { + name: "cache size == unbonding queue entries", + maxCacheSize: 2, + description: "should use cache when cache size matches entries", + }, + { + name: "cache size < unbonding queue entries", + maxCacheSize: 1, + description: "should fallback to store when cache size is exceeded", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := 
sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + accountKeeper := testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + + // Initialize keeper with specific cache size + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + tc.maxCacheSize, + ) + s.Require().NoError(keeper.SetParams(ctx, stakingtypes.DefaultParams())) + + blockTime := time.Now().UTC() + blockHeight := int64(1000) + ctx = ctx.WithBlockHeight(blockHeight).WithBlockTime(blockTime) + + iterator, err := keeper.ValidatorQueueIterator(ctx, blockTime, blockHeight) + s.Require().NoError(err) + defer iterator.Close() + count := 0 + for ; iterator.Valid(); iterator.Next() { + count++ + } + // no unbonding validator in the queue initially + s.Require().Equal(0, count) + + // add ready to unbond validator + valPubKey := PKs[0] + valAddr := sdk.ValAddress(valPubKey.Address().Bytes()) + val := testutil.NewValidator(s.T(), valAddr, valPubKey) + val.UnbondingHeight = blockHeight + val.UnbondingTime = blockTime + val.Status = stakingtypes.Unbonding + s.Require().NoError(keeper.SetValidator(ctx, val)) + s.Require().NoError(keeper.InsertUnbondingValidatorQueue(ctx, val)) + + // add another unbonding validator with same unbonding time and height + valAddr1 := 
sdk.ValAddress(PKs[1].Address().Bytes()) + validator1 := testutil.NewValidator(s.T(), valAddr1, PKs[1]) + valUnbondingHeight1 := blockHeight + valUnbondingTime1 := blockTime + validator1.UnbondingHeight = valUnbondingHeight1 + validator1.UnbondingTime = valUnbondingTime1 + validator1.Status = stakingtypes.Unbonding + s.Require().NoError(keeper.SetValidator(ctx, validator1)) + s.Require().NoError(keeper.InsertUnbondingValidatorQueue(ctx, validator1)) + + iterator1, err := keeper.ValidatorQueueIterator(ctx, blockTime, blockHeight) + s.Require().NoError(err) + defer iterator1.Close() + count1 := 0 + for ; iterator1.Valid(); iterator1.Next() { + count1++ + } + + // unbonding validator should be retrieved + // count 1 due to same unbonding time and height + s.Require().Equal(1, count1) + + // Verify GetUnbondingValidators returns the correct validators after insertion + unbondingVals, err := keeper.GetUnbondingValidators(ctx, blockTime, blockHeight) + s.Require().NoError(err) + s.Require().Equal(2, len(unbondingVals), "should have 2 validators at same time/height") + s.Require().Contains(unbondingVals, val.OperatorAddress) + s.Require().Contains(unbondingVals, validator1.OperatorAddress) + + // add another unbonding validator with different unbonding time and height + valAddr2 := sdk.ValAddress(PKs[2].Address().Bytes()) + validator2 := testutil.NewValidator(s.T(), valAddr2, PKs[2]) + valUnbondingHeight2 := blockHeight - 10 + valUnbondingTime2 := blockTime.Add(-1 * time.Minute) + validator2.UnbondingHeight = valUnbondingHeight2 + validator2.UnbondingTime = valUnbondingTime2 + validator2.Status = stakingtypes.Unbonding + s.Require().NoError(keeper.SetValidator(ctx, validator2)) + s.Require().NoError(keeper.InsertUnbondingValidatorQueue(ctx, validator2)) + + iterator2, err := keeper.ValidatorQueueIterator(ctx, blockTime, blockHeight) + s.Require().NoError(err) + defer iterator2.Close() + count2 := 0 + for ; iterator2.Valid(); iterator2.Next() { + count2++ + } + + // 
unbonding validator should be retrieved + s.Require().Equal(2, count2) + + // Verify the new validator was inserted at the correct time/height + unbondingVals2, err := keeper.GetUnbondingValidators(ctx, validator2.UnbondingTime, validator2.UnbondingHeight) + s.Require().NoError(err) + s.Require().Equal(1, len(unbondingVals2), "should have 1 validator at different time/height") + s.Require().Contains(unbondingVals2, validator2.OperatorAddress) + }) + } +} + +func (s *KeeperTestSuite) TestGetUnbondingValidators() { + testCases := []struct { + name string + maxCacheSize int + description string + }{ + { + name: "cache size < 0 (cache disabled)", + maxCacheSize: -1, + description: "should always read from store when cache is not initialized", + }, + { + name: "cache size = 0 (unlimited cache)", + maxCacheSize: 0, + description: "should use unlimited cache with no size restrictions", + }, + { + name: "cache size > unbonding queue entries", + maxCacheSize: 10, + description: "should use cache when cache is large enough", + }, + { + name: "cache size == unbonding queue entries", + maxCacheSize: 3, + description: "should use cache when cache size matches entries", + }, + { + name: "cache size < unbonding queue entries", + maxCacheSize: 1, + description: "should fallback to store when cache size is exceeded", + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + accountKeeper := testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + 
accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + tc.maxCacheSize, + ) + s.Require().NoError(keeper.SetParams(ctx, stakingtypes.DefaultParams())) + + baseTime := time.Now().UTC() + baseHeight := int64(1000) + ctx = ctx.WithBlockHeight(baseHeight).WithBlockTime(baseTime) + + // Create validators unbonding at different times/heights + // Group 1: Two validators at baseTime, baseHeight + val1 := testutil.NewValidator(s.T(), sdk.ValAddress(PKs[0].Address().Bytes()), PKs[0]) + val1.UnbondingHeight = baseHeight + val1.UnbondingTime = baseTime + val1.Status = stakingtypes.Unbonding + s.Require().NoError(keeper.SetValidator(ctx, val1)) + s.Require().NoError(keeper.InsertUnbondingValidatorQueue(ctx, val1)) + + val2 := testutil.NewValidator(s.T(), sdk.ValAddress(PKs[1].Address().Bytes()), PKs[1]) + val2.UnbondingHeight = baseHeight + val2.UnbondingTime = baseTime + val2.Status = stakingtypes.Unbonding + s.Require().NoError(keeper.SetValidator(ctx, val2)) + s.Require().NoError(keeper.InsertUnbondingValidatorQueue(ctx, val2)) + + // Group 2: One validator at different time/height + val3 := testutil.NewValidator(s.T(), sdk.ValAddress(PKs[2].Address().Bytes()), PKs[2]) + val3.UnbondingHeight = baseHeight + 10 + val3.UnbondingTime = baseTime.Add(1 * time.Hour) + val3.Status = stakingtypes.Unbonding + s.Require().NoError(keeper.SetValidator(ctx, val3)) + s.Require().NoError(keeper.InsertUnbondingValidatorQueue(ctx, val3)) + + // Group 3: One validator at another different time/height + val4 := testutil.NewValidator(s.T(), 
sdk.ValAddress(PKs[3].Address().Bytes()), PKs[3]) + val4.UnbondingHeight = baseHeight - 5 + val4.UnbondingTime = baseTime.Add(-30 * time.Minute) + val4.Status = stakingtypes.Unbonding + s.Require().NoError(keeper.SetValidator(ctx, val4)) + s.Require().NoError(keeper.InsertUnbondingValidatorQueue(ctx, val4)) + + // Test 1: Get validators for group 1 (baseTime, baseHeight) + // Should return 2 validators + unbondingVals1, err := keeper.GetUnbondingValidators(ctx, baseTime, baseHeight) + s.Require().NoError(err) + s.Require().Equal(2, len(unbondingVals1), "should have 2 validators at baseTime/baseHeight") + + // Verify the correct validators are returned + s.Require().Contains(unbondingVals1, val1.OperatorAddress) + s.Require().Contains(unbondingVals1, val2.OperatorAddress) + + // Test 2: Get validators for group 2 (baseTime + 1 hour, baseHeight + 10) + // Should return 1 validator + unbondingVals2, err := keeper.GetUnbondingValidators(ctx, val3.UnbondingTime, val3.UnbondingHeight) + s.Require().NoError(err) + s.Require().Equal(1, len(unbondingVals2), "should have 1 validator at baseTime+1hour/baseHeight+10") + s.Require().Contains(unbondingVals2, val3.OperatorAddress) + + // Test 3: Get validators for group 3 (baseTime - 30 min, baseHeight - 5) + // Should return 1 validator + unbondingVals3, err := keeper.GetUnbondingValidators(ctx, val4.UnbondingTime, val4.UnbondingHeight) + s.Require().NoError(err) + s.Require().Equal(1, len(unbondingVals3), "should have 1 validator at baseTime-30min/baseHeight-5") + s.Require().Contains(unbondingVals3, val4.OperatorAddress) + + // Test 4: Get validators for a time/height with no validators + // Should return empty slice + emptyTime := baseTime.Add(2 * time.Hour) + emptyHeight := baseHeight + 100 + unbondingValsEmpty, err := keeper.GetUnbondingValidators(ctx, emptyTime, emptyHeight) + s.Require().NoError(err) + s.Require().Equal(0, len(unbondingValsEmpty), "should have 0 validators at non-existent time/height") + + // Test 5: Call 
GetUnbondingValidators again to verify cache consistency + // This ensures second call returns same results (cache hit scenario) + unbondingVals1Again, err := keeper.GetUnbondingValidators(ctx, baseTime, baseHeight) + s.Require().NoError(err) + s.Require().Equal(len(unbondingVals1), len(unbondingVals1Again), "repeated call should return same number of validators") + s.Require().ElementsMatch(unbondingVals1, unbondingVals1Again, "repeated call should return same validators") + + // Verify total validator count + allValidators, err := keeper.GetAllValidators(ctx) + s.Require().NoError(err) + s.Require().Equal(4, len(allValidators), "should have 4 total validators") + }) + } +} + +func (s *KeeperTestSuite) TestUnbondAllMatureValidators() { + testCases := []struct { + name string + maxCacheSize int + numUnbondingValidators int + }{ + { + name: "cache size < 0 (cache disabled)", + maxCacheSize: -1, + numUnbondingValidators: 3, + }, + { + name: "cache size = 0 (unlimited cache)", + maxCacheSize: 0, + numUnbondingValidators: 3, + }, + { + name: "cache size > unbonding validators", + maxCacheSize: 3, + numUnbondingValidators: 2, + }, + { + name: "cache size == unbonding validators", + maxCacheSize: 2, + numUnbondingValidators: 2, + }, + { + name: "cache size < unbonding validators", + maxCacheSize: 1, + numUnbondingValidators: 2, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + accountKeeper := testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + 
accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + + // Initialize keeper with specific cache size + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + tc.maxCacheSize, + ) + s.Require().NoError(keeper.SetParams(ctx, stakingtypes.DefaultParams())) + + blockTime := time.Now().UTC() + blockHeight := int64(1000) + ctx = ctx.WithBlockHeight(blockHeight).WithBlockTime(blockTime) + + // Create multiple unbonding validators that are ready to unbond + for i := 0; i < tc.numUnbondingValidators; i++ { + valPubKey := PKs[i] + valAddr := sdk.ValAddress(valPubKey.Address().Bytes()) + val := testutil.NewValidator(s.T(), valAddr, valPubKey) + val.UnbondingHeight = blockHeight + val.UnbondingTime = blockTime + val.Status = stakingtypes.Unbonding + s.Require().NoError(keeper.SetValidator(ctx, val)) + s.Require().NoError(keeper.InsertUnbondingValidatorQueue(ctx, val)) + } + + // Verify we have the expected number of validators before unbonding + allValidators, err := keeper.GetAllValidators(ctx) + s.Require().NoError(err) + s.Require().Equal(tc.numUnbondingValidators, len(allValidators), "should have all validators before unbonding") + + // Verify GetUnbondingValidators returns the expected number of validators. + // In this case, it should return all validators as all validators are unbonding at the same height and time. 
+ unbondingValidators, err := keeper.GetUnbondingValidators(ctx, blockTime, blockHeight) + s.Require().NoError(err) + s.Require().Equal(tc.numUnbondingValidators, len(unbondingValidators)) + + err = keeper.UnbondAllMatureValidators(ctx) + s.Require().NoError(err) + + // Verify all validators were unbonded (removed from the store) + allValidatorsAfter, err := keeper.GetAllValidators(ctx) + s.Require().NoError(err) + s.Require().Equal(0, len(allValidatorsAfter)) + }) + } +} + +func (s *KeeperTestSuite) TestUnbondingValidatorQueueCacheRecovery() { + // This test verifies that when the cache is initially too small (exceeded), + // and then entries are dequeued, the cache can recover and be used again + // Cache size is based on the number of unique time+height keys, not individual validators + key := storetypes.NewKVStoreKey(stakingtypes.StoreKey) + storeService := runtime.NewKVStoreService(key) + testCtx := sdktestutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + ctx := testCtx.Ctx.WithBlockHeader(cmtproto.Header{Time: cmttime.Now()}) + encCfg := moduletestutil.MakeTestEncodingConfig() + + ctrl := gomock.NewController(s.T()) + accountKeeper := testutil.NewMockAccountKeeper(ctrl) + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.BondedPoolName).Return(bondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().GetModuleAddress(stakingtypes.NotBondedPoolName).Return(notBondedAcc.GetAddress()).AnyTimes() + accountKeeper.EXPECT().AddressCodec().Return(address.NewBech32Codec("cosmos")).AnyTimes() + + bankKeeper := testutil.NewMockBankKeeper(ctrl) + + // Initialize keeper with small cache size (2 time+height keys) + maxCacheSize := 2 + keeper := stakingkeeper.NewKeeper( + encCfg.Codec, + storeService, + accountKeeper, + bankKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + address.NewBech32Codec("cosmosvaloper"), + address.NewBech32Codec("cosmosvalcons"), + maxCacheSize, + ) + params := 
stakingtypes.DefaultParams() + params.UnbondingTime = 1 * time.Hour // Use 1 hour so we can create different timestamps + s.Require().NoError(keeper.SetParams(ctx, params)) + + baseTime := time.Now().UTC() + baseHeight := int64(1000) + + // Create validators at 5 different time+height combinations (exceeds cache size of 2) + numKeys := 5 + validators := make([]stakingtypes.Validator, numKeys) + unbondingTimes := make([]time.Time, numKeys) + unbondingHeights := make([]int64, numKeys) + + for i := 0; i < numKeys; i++ { + // Create unique time+height combinations + unbondingTimes[i] = baseTime.Add(time.Duration(i) * time.Hour) + unbondingHeights[i] = baseHeight + int64(i*10) + + valPubKey := PKs[i] + valAddr := sdk.ValAddress(valPubKey.Address().Bytes()) + validators[i] = testutil.NewValidator(s.T(), valAddr, valPubKey) + validators[i].UnbondingHeight = unbondingHeights[i] + validators[i].UnbondingTime = unbondingTimes[i] + validators[i].Status = stakingtypes.Unbonding + s.Require().NoError(keeper.SetValidator(ctx, validators[i])) + s.Require().NoError(keeper.InsertUnbondingValidatorQueue(ctx, validators[i])) + } + + // Verify all validators were created + for i := 0; i < numKeys; i++ { + _, err := keeper.GetValidator(ctx, sdk.ValAddress(PKs[i].Address().Bytes())) + s.Require().NoError(err) + } + + // At this point, cache should be exceeded (5 keys > maxCacheSize of 2) + // GetPendingUnbondingValidators should still work, but will read from store + ctx = ctx.WithBlockTime(unbondingTimes[numKeys-1]).WithBlockHeight(unbondingHeights[numKeys-1]) + allValidators, err := keeper.GetPendingUnbondingValidators(ctx, unbondingTimes[numKeys-1], unbondingHeights[numKeys-1]) + s.Require().NoError(err) + s.Require().Equal(5, len(allValidators), "should have 5 different time+height keys") + + // Mature and unbond the first 3 validators (removing 3 keys) + // Fast-forward to time when first 3 validators are mature + ctx = 
ctx.WithBlockTime(unbondingTimes[2]).WithBlockHeight(unbondingHeights[2]) + err = keeper.UnbondAllMatureValidators(ctx) + s.Require().NoError(err) + + // Verify the first 3 are now unbonded (not in unbonding queue) + for i := 0; i < 3; i++ { + // Verify not in queue + vals, err := keeper.GetUnbondingValidators(ctx, unbondingTimes[i], unbondingHeights[i]) + s.Require().NoError(err) + s.Require().Equal(0, len(vals), "should have no validators at this time+height") + } + + // Verify the last 2 are still unbonding + for i := 3; i < numKeys; i++ { + valAddr := sdk.ValAddress(PKs[i].Address().Bytes()) + validator, err := keeper.GetValidator(ctx, valAddr) + s.Require().NoError(err) + s.Require().Equal(stakingtypes.Unbonding, validator.Status, "validator should still be unbonding") + + // Verify still in queue + vals, err := keeper.GetUnbondingValidators(ctx, unbondingTimes[i], unbondingHeights[i]) + s.Require().NoError(err) + s.Require().Equal(1, len(vals), "should have 1 validator at this time+height") + } + + // Now only 2 time+height keys remain + // This fits in the cache (2 keys == maxCacheSize) + // GetPendingUnbondingValidators should now be able to use the cache + ctx = ctx.WithBlockTime(unbondingTimes[4]).WithBlockHeight(unbondingHeights[4]) + remainingValidators, err := keeper.GetPendingUnbondingValidators(ctx, unbondingTimes[4], unbondingHeights[4]) + s.Require().NoError(err) + s.Require().Equal(2, len(remainingValidators), "should have 2 time+height keys in cache") + + // Unbond the remaining 2 validators + // Fast-forward to time when all remaining validators are mature + ctx = ctx.WithBlockTime(unbondingTimes[4]).WithBlockHeight(unbondingHeights[4]) + err = keeper.UnbondAllMatureValidators(ctx) + s.Require().NoError(err) + + // Verify all validators are now unbonded + for i := 3; i < numKeys; i++ { + // Verify not in queue + vals, err := keeper.GetUnbondingValidators(ctx, unbondingTimes[i], unbondingHeights[i]) + s.Require().NoError(err) + 
s.Require().Equal(0, len(vals), "should have no validators in queue") + } +} + +func (s *KeeperTestSuite) TestSortValidatorQueueKeysByAscendingTimestampOrder() { + require := s.Require() + + currentTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + oneHourLater := currentTime.Add(1 * time.Hour) + oneHourBefore := currentTime.Add(-1 * time.Hour) + + keys := []string{ + stakingtypes.GetCacheValidatorQueueKey(oneHourLater, 1000), + stakingtypes.GetCacheValidatorQueueKey(oneHourBefore, 500), + stakingtypes.GetCacheValidatorQueueKey(currentTime, 750), + stakingtypes.GetCacheValidatorQueueKey(oneHourBefore, 600), + stakingtypes.GetCacheValidatorQueueKey(oneHourLater, 900), + } + + stakingtypes.SortValidatorQueueKeysByAscendingTimestampOrder(keys) + + // Verify sorting is correct - should be sorted by timestamp ascending order + for i := 0; i < len(keys)-1; i++ { + t1, _, err := stakingtypes.ParseCacheValidatorQueueKey(keys[i]) + require.NoError(err) + t2, _, err := stakingtypes.ParseCacheValidatorQueueKey(keys[i+1]) + require.NoError(err) + + // Current entry should be before or equal to next entry + require.True(t1.Before(t2) || t1.Equal(t2), "timestamps should be in ascending order") + + } + + firstTime, _, err := stakingtypes.ParseCacheValidatorQueueKey(keys[0]) + require.NoError(err) + require.Equal(oneHourBefore, firstTime) + + lastTime, _, err := stakingtypes.ParseCacheValidatorQueueKey(keys[len(keys)-1]) + require.NoError(err) + require.Equal(oneHourLater, lastTime) +} + +func (s *KeeperTestSuite) TestGetAndParseCacheValidatorQueueKey() { + require := s.Require() + + blockTime := time.Now().UTC() + blockHeight := int64(1000) + key := stakingtypes.GetCacheValidatorQueueKey(blockTime, blockHeight) + time, height, err := stakingtypes.ParseCacheValidatorQueueKey(key) + require.NoError(err) + require.Equal(blockTime, time) + require.Equal(blockHeight, height) +} diff --git a/x/staking/module.go b/x/staking/module.go index 9fde311dffae..fde83a88842b 100644 --- 
a/x/staking/module.go
+++ b/x/staking/module.go
@@ -205,6 +205,7 @@ type ModuleInputs struct {
 	BankKeeper   types.BankKeeper
 	Cdc          codec.Codec
 	StoreService store.KVStoreService
+	MaxCacheSize int `optional:"true"`
 
 	// LegacySubspace is used solely for migration of x/params managed parameters
 	LegacySubspace exported.Subspace `optional:"true"`
@@ -233,6 +234,7 @@ func ProvideModule(in ModuleInputs) ModuleOutputs {
 		authority.String(),
 		in.ValidatorAddressCodec,
 		in.ConsensusAddressCodec,
+		in.MaxCacheSize, // MaxCacheSize defaults to 0 (unlimited cache) if not provided
 	)
 	m := NewAppModule(in.Cdc, k, in.AccountKeeper, in.BankKeeper, in.LegacySubspace)
 	return ModuleOutputs{StakingKeeper: k, Module: m}
diff --git a/x/staking/types/errors.go b/x/staking/types/errors.go
index 00441564352a..ad747fe203df 100644
--- a/x/staking/types/errors.go
+++ b/x/staking/types/errors.go
@@ -48,4 +48,5 @@ var (
 	ErrInvalidSigner       = errors.Register(ModuleName, 43, "expected authority account as only signer for proposal message")
 	ErrBadRedelegationSrc  = errors.Register(ModuleName, 44, "redelegation source validator not found")
 	ErrNoUnbondingType     = errors.Register(ModuleName, 45, "unbonding type not found")
+	ErrCacheMaxSizeReached = errors.Register(ModuleName, 46, "cache max size reached")
 )
diff --git a/x/staking/types/keys.go b/x/staking/types/keys.go
index f9e7a985acd5..e24a16eacfcd 100644
--- a/x/staking/types/keys.go
+++ b/x/staking/types/keys.go
@@ -4,6 +4,9 @@ import (
 	"bytes"
 	"encoding/binary"
 	"fmt"
+	"sort"
+	"strconv"
+	"strings"
 	"time"
 
 	addresscodec "cosmossdk.io/core/address"
@@ -170,6 +173,33 @@ func GetValidatorQueueKey(timestamp time.Time, height int64) []byte {
 	return bz
 }
 
+// GetCacheValidatorQueueKey returns a key for the cache for unbonding validators for a given time and height.
+func GetCacheValidatorQueueKey(t time.Time, height int64) string {
+	return fmt.Sprintf("%s/%d", t.Format(sdk.SortableTimeFormat), height)
+}
+
+// ParseCacheValidatorQueueKey parses a key for the cache for unbonding validators for a given time and height.
+func ParseCacheValidatorQueueKey(key string) (time.Time, int64, error) {
+	parts := strings.Split(key, "/")
+	// Guard against malformed keys: indexing parts[1] below would panic
+	// if the "/" separator is missing.
+	if len(parts) != 2 {
+		return time.Time{}, 0, fmt.Errorf("invalid cache validator queue key: %q", key)
+	}
+
+	t, err := sdk.ParseTime(parts[0])
+	if err != nil {
+		return time.Time{}, 0, err
+	}
+
+	height, err := strconv.ParseInt(parts[1], 10, 64)
+	if err != nil {
+		return time.Time{}, 0, err
+	}
+
+	return t, height, nil
+}
+
 // ParseValidatorQueueKey returns the encoded time and height from a key created
 // from GetValidatorQueueKey.
 func ParseValidatorQueueKey(bz []byte) (time.Time, int64, error) {
@@ -283,6 +313,21 @@ func GetUnbondingDelegationTimeKey(timestamp time.Time) []byte {
 	return append(UnbondingQueueKey, bz...)
 }
 
+// ParseUnbondingDelegationTimeKey parses the unbonding delegation time key and returns the timestamp
+func ParseUnbondingDelegationTimeKey(bz []byte) (time.Time, error) {
+	prefixL := len(UnbondingQueueKey)
+	if len(bz) <= prefixL {
+		return time.Time{}, fmt.Errorf("invalid key length; expected at least %d bytes, got %d", prefixL+1, len(bz))
+	}
+
+	if prefix := bz[:prefixL]; !bytes.Equal(prefix, UnbondingQueueKey) {
+		return time.Time{}, fmt.Errorf("invalid prefix; expected: %X, got: %X", UnbondingQueueKey, prefix)
+	}
+
+	timeBz := bz[prefixL:]
+	return sdk.ParseTimeBytes(timeBz)
+}
+
 // GetREDKey returns a key prefix for indexing a redelegation from a delegator
 // and source validator to a destination validator.
 func GetREDKey(delAddr sdk.AccAddress, valSrcAddr, valDstAddr sdk.ValAddress) []byte {
@@ -375,6 +420,21 @@ func GetRedelegationTimeKey(timestamp time.Time) []byte {
 	return append(RedelegationQueueKey, bz...)
 }
 
+// ParseRedelegationTimeKey parses the redelegation time key and returns the timestamp
+func ParseRedelegationTimeKey(bz []byte) (time.Time, error) {
+	prefixL := len(RedelegationQueueKey)
+	if len(bz) <= prefixL {
+		return time.Time{}, fmt.Errorf("invalid key length; expected at least %d bytes, got %d", prefixL+1, len(bz))
+	}
+
+	if prefix := bz[:prefixL]; !bytes.Equal(prefix, RedelegationQueueKey) {
+		return time.Time{}, fmt.Errorf("invalid prefix; expected: %X, got: %X", RedelegationQueueKey, prefix)
+	}
+
+	timeBz := bz[prefixL:]
+	return sdk.ParseTimeBytes(timeBz)
+}
+
 // GetREDsKey returns a key prefix for indexing a redelegation from a delegator
 // address.
 func GetREDsKey(delAddr sdk.AccAddress) []byte {
@@ -405,3 +465,23 @@ func GetHistoricalInfoKey(height int64) []byte {
 	binary.BigEndian.PutUint64(heightBytes, uint64(height))
 	return append(HistoricalInfoKey, heightBytes...)
 }
+
+// SortTimestampsByAscendingOrder sorts the timestamps by ascending order.
+// Keys that fail to parse are treated as the zero time and sort first.
+func SortTimestampsByAscendingOrder(keys []string) {
+	sort.Slice(keys, func(i, j int) bool {
+		t1, _ := sdk.ParseTime(keys[i])
+		t2, _ := sdk.ParseTime(keys[j])
+		return t1.Before(t2)
+	})
+}
+
+// SortValidatorQueueKeysByAscendingTimestampOrder sorts the validator queue keys by ascending timestamp.
+// Keys that fail to parse are treated as the zero time and sort first.
+func SortValidatorQueueKeysByAscendingTimestampOrder(keys []string) {
+	sort.Slice(keys, func(i, j int) bool {
+		t1, _, _ := ParseCacheValidatorQueueKey(keys[i])
+		t2, _, _ := ParseCacheValidatorQueueKey(keys[j])
+		return t1.Before(t2)
+	})
+}