diff --git a/protocol/testutil/constants/vault.go b/protocol/testutil/constants/vault.go new file mode 100644 index 0000000000..eba5e6120d --- /dev/null +++ b/protocol/testutil/constants/vault.go @@ -0,0 +1,16 @@ +package constants + +import ( + vaulttypes "github.com/dydxprotocol/v4-chain/protocol/x/vault/types" +) + +var ( + Vault_Clob_0 = vaulttypes.VaultId{ + Type: vaulttypes.VaultType_VAULT_TYPE_CLOB, + Number: 0, + } + Vault_Clob_1 = vaulttypes.VaultId{ + Type: vaulttypes.VaultType_VAULT_TYPE_CLOB, + Number: 1, + } +) diff --git a/protocol/x/vault/keeper/shares.go b/protocol/x/vault/keeper/shares.go new file mode 100644 index 0000000000..467af6cf24 --- /dev/null +++ b/protocol/x/vault/keeper/shares.go @@ -0,0 +1,41 @@ +package keeper + +import ( + "cosmossdk.io/store/prefix" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/dydxprotocol/v4-chain/protocol/dtypes" + "github.com/dydxprotocol/v4-chain/protocol/x/vault/types" +) + +// GetTotalShares gets TotalShares for a vault. +func (k Keeper) GetTotalShares( + ctx sdk.Context, + vaultId types.VaultId, +) (val types.NumShares, exists bool) { + store := prefix.NewStore(ctx.KVStore(k.storeKey), []byte(types.TotalSharesKeyPrefix)) + + b := store.Get(vaultId.ToStateKey()) + if b == nil { + return val, false + } + + k.cdc.MustUnmarshal(b, &val) + return val, true +} + +// SetTotalShares sets TotalShares for a vault. Returns error if `totalShares` is negative. +func (k Keeper) SetTotalShares( + ctx sdk.Context, + vaultId types.VaultId, + totalShares types.NumShares, +) error { + if totalShares.NumShares.Cmp(dtypes.NewInt(0)) == -1 { + return types.ErrNegativeShares + } + + b := k.cdc.MustMarshal(&totalShares) + totalSharesStore := prefix.NewStore(ctx.KVStore(k.storeKey), []byte(types.TotalSharesKeyPrefix)) + totalSharesStore.Set(vaultId.ToStateKey(), b) + + return nil +} diff --git a/protocol/x/vault/keeper/shares_test.go b/protocol/x/vault/keeper/shares_test.go new file mode 100644 index 0000000000..06454b0b80 --- /dev/null +++ b/protocol/x/vault/keeper/shares_test.go @@ -0,0 +1,67 @@ +package keeper_test + +import ( + "testing" + + "github.com/dydxprotocol/v4-chain/protocol/dtypes" + testapp "github.com/dydxprotocol/v4-chain/protocol/testutil/app" + "github.com/dydxprotocol/v4-chain/protocol/testutil/constants" + "github.com/dydxprotocol/v4-chain/protocol/x/vault/types" + "github.com/stretchr/testify/require" +) + +func TestGetSetTotalShares(t *testing.T) { + tApp := testapp.NewTestAppBuilder(t).Build() + ctx := tApp.InitChain() + k := tApp.App.VaultKeeper + + // Get total shares for a non-existing vault. + _, exists := k.GetTotalShares(ctx, constants.Vault_Clob_0) + require.Equal(t, false, exists) + + // Set total shares for a vault and then get. + err := k.SetTotalShares(ctx, constants.Vault_Clob_0, types.NumShares{ + NumShares: dtypes.NewInt(7), + }) + require.NoError(t, err) + numShares, exists := k.GetTotalShares(ctx, constants.Vault_Clob_0) + require.Equal(t, true, exists) + require.Equal(t, dtypes.NewInt(7), numShares.NumShares) + + // Set total shares for another vault and then get. + err = k.SetTotalShares(ctx, constants.Vault_Clob_1, types.NumShares{ + NumShares: dtypes.NewInt(456), + }) + require.NoError(t, err) + numShares, exists = k.GetTotalShares(ctx, constants.Vault_Clob_1) + require.Equal(t, true, exists) + require.Equal(t, dtypes.NewInt(456), numShares.NumShares) + + // Set total shares for second vault to 0. + err = k.SetTotalShares(ctx, constants.Vault_Clob_1, types.NumShares{ + NumShares: dtypes.NewInt(0), + }) + require.NoError(t, err) + numShares, exists = k.GetTotalShares(ctx, constants.Vault_Clob_1) + require.Equal(t, true, exists) + require.Equal(t, dtypes.NewInt(0), numShares.NumShares) + + // Set total shares for the first vault again and then get. + err = k.SetTotalShares(ctx, constants.Vault_Clob_0, types.NumShares{ + NumShares: dtypes.NewInt(123), + }) + require.NoError(t, err) + numShares, exists = k.GetTotalShares(ctx, constants.Vault_Clob_0) + require.Equal(t, true, exists) + require.Equal(t, dtypes.NewInt(123), numShares.NumShares) + + // Set total shares for the first vault to a negative value. + // Should get error and total shares should remain unchanged. + err = k.SetTotalShares(ctx, constants.Vault_Clob_0, types.NumShares{ + NumShares: dtypes.NewInt(-1), + }) + require.Equal(t, types.ErrNegativeShares, err) + numShares, exists = k.GetTotalShares(ctx, constants.Vault_Clob_0) + require.Equal(t, true, exists) + require.Equal(t, dtypes.NewInt(123), numShares.NumShares) +} diff --git a/protocol/x/vault/types/errors.go b/protocol/x/vault/types/errors.go new file mode 100644 index 0000000000..68d2ed719d --- /dev/null +++ b/protocol/x/vault/types/errors.go @@ -0,0 +1,13 @@ +package types + +// DONTCOVER + +import errorsmod "cosmossdk.io/errors" + +var ( + ErrNegativeShares = errorsmod.Register( + ModuleName, + 1, + "Shares are negative", + ) +) diff --git a/protocol/x/vault/types/keys.go b/protocol/x/vault/types/keys.go index 7e130df0e9..3ff5c4c37e 100644 --- a/protocol/x/vault/types/keys.go +++ b/protocol/x/vault/types/keys.go @@ -1,10 +1,16 @@ package types -// Module name and store keys +// Module name and store keys. const ( - // ModuleName defines the module name + // ModuleName defines the module name. ModuleName = "vault" - // StoreKey defines the primary module store key + // StoreKey defines the primary module store key. StoreKey = ModuleName ) + +// State. +const ( + // TotalSharesKeyPrefix is the prefix to retrieve all TotalShares. + TotalSharesKeyPrefix = "TotalShares:" +) diff --git a/protocol/x/vault/types/keys_test.go b/protocol/x/vault/types/keys_test.go new file mode 100644 index 0000000000..6b70d1828c --- /dev/null +++ b/protocol/x/vault/types/keys_test.go @@ -0,0 +1,17 @@ +package types_test + +import ( + "testing" + + "github.com/dydxprotocol/v4-chain/protocol/x/vault/types" + "github.com/stretchr/testify/require" +) + +func TestModuleKeys(t *testing.T) { + require.Equal(t, "vault", types.ModuleName) + require.Equal(t, "vault", types.StoreKey) +} + +func TestStateKeys(t *testing.T) { + require.Equal(t, "TotalShares:", types.TotalSharesKeyPrefix) +} diff --git a/protocol/x/vault/types/vault_id.go b/protocol/x/vault/types/vault_id.go new file mode 100644 index 0000000000..131d16f0f9 --- /dev/null +++ b/protocol/x/vault/types/vault_id.go @@ -0,0 +1,9 @@ +package types + +func (id *VaultId) ToStateKey() []byte { + b, err := id.Marshal() + if err != nil { + panic(err) + } + return b +} diff --git a/protocol/x/vault/types/vault_id_test.go b/protocol/x/vault/types/vault_id_test.go new file mode 100644 index 0000000000..b7e4c77b1f --- /dev/null +++ b/protocol/x/vault/types/vault_id_test.go @@ -0,0 +1,16 @@ +package types_test + +import ( + "testing" + + "github.com/dydxprotocol/v4-chain/protocol/testutil/constants" + "github.com/stretchr/testify/require" +) + +func TestToStateKey(t *testing.T) { + b, _ := constants.Vault_Clob_0.Marshal() + require.Equal(t, b, constants.Vault_Clob_0.ToStateKey()) + + b, _ = constants.Vault_Clob_1.Marshal() + require.Equal(t, b, constants.Vault_Clob_1.ToStateKey()) +}