Skip to content

Commit

Permalink
return error if shares is negative
Browse files Browse the repository at this point in the history
  • Loading branch information
tqin7 committed Mar 15, 2024
1 parent 283092d commit 5614707
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
11 changes: 9 additions & 2 deletions protocol/x/vault/keeper/shares.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package keeper
import (
"cosmossdk.io/store/prefix"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/dydxprotocol/v4-chain/protocol/dtypes"
"github.com/dydxprotocol/v4-chain/protocol/x/vault/types"
)

Expand All @@ -22,13 +23,19 @@ func (k Keeper) GetTotalShares(
return val, true
}

// SetTotalShares sets TotalShares for a vault.
// SetTotalShares sets TotalShares for a vault. Returns error if `totalShares` is negative.
func (k Keeper) SetTotalShares(
ctx sdk.Context,
vaultId types.VaultId,
totalShares types.NumShares,
) {
) error {
if totalShares.NumShares.Cmp(dtypes.NewInt(0)) == -1 {
return types.ErrNegativeShares
}

b := k.cdc.MustMarshal(&totalShares)
totalSharesStore := prefix.NewStore(ctx.KVStore(k.storeKey), []byte(types.TotalSharesKeyPrefix))
totalSharesStore.Set(vaultId.ToStateKey(), b)

return nil
}
26 changes: 24 additions & 2 deletions protocol/x/vault/keeper/shares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,47 @@ func TestGetSetTotalShares(t *testing.T) {
require.Equal(t, false, exists)

// Set total shares for a vault and then get.
k.SetTotalShares(ctx, constants.Vault_Clob_0, types.NumShares{
err := k.SetTotalShares(ctx, constants.Vault_Clob_0, types.NumShares{
NumShares: dtypes.NewInt(7),
})
require.NoError(t, err)
numShares, exists := k.GetTotalShares(ctx, constants.Vault_Clob_0)
require.Equal(t, true, exists)
require.Equal(t, dtypes.NewInt(7), numShares.NumShares)

// Set total shares for another vault and then get.
k.SetTotalShares(ctx, constants.Vault_Clob_1, types.NumShares{
err = k.SetTotalShares(ctx, constants.Vault_Clob_1, types.NumShares{
NumShares: dtypes.NewInt(456),
})
require.NoError(t, err)
numShares, exists = k.GetTotalShares(ctx, constants.Vault_Clob_1)
require.Equal(t, true, exists)
require.Equal(t, dtypes.NewInt(456), numShares.NumShares)

// Set total shares for second vault to 0.
err = k.SetTotalShares(ctx, constants.Vault_Clob_1, types.NumShares{
NumShares: dtypes.NewInt(0),
})
require.NoError(t, err)
numShares, exists = k.GetTotalShares(ctx, constants.Vault_Clob_1)
require.Equal(t, true, exists)
require.Equal(t, dtypes.NewInt(0), numShares.NumShares)

// Set total shares for the first vault again and then get.
k.SetTotalShares(ctx, constants.Vault_Clob_0, types.NumShares{
NumShares: dtypes.NewInt(123),
})
require.NoError(t, err)
numShares, exists = k.GetTotalShares(ctx, constants.Vault_Clob_0)
require.Equal(t, true, exists)
require.Equal(t, dtypes.NewInt(123), numShares.NumShares)

// Set total shares for the first vault to a negative value.
// Should get error and total shares should remain unchanged.
err = k.SetTotalShares(ctx, constants.Vault_Clob_0, types.NumShares{
NumShares: dtypes.NewInt(-1),
})
require.Equal(t, types.ErrNegativeShares, err)
numShares, exists = k.GetTotalShares(ctx, constants.Vault_Clob_0)
require.Equal(t, true, exists)
require.Equal(t, dtypes.NewInt(123), numShares.NumShares)
Expand Down
13 changes: 13 additions & 0 deletions protocol/x/vault/types/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package types

// DONTCOVER

import errorsmod "cosmossdk.io/errors"

var (
ErrNegativeShares = errorsmod.Register(
ModuleName,
1,
"Shares are negative",
)
)

0 comments on commit 5614707

Please sign in to comment.