Skip to content

Commit

Permalink
[stableswap]: Cap number of assets and post-scaled asset amounts to e…
Browse files Browse the repository at this point in the history
…nsure pools never overflow (#3055)

* add tests for 10-asset pools with 10B per asset

* add max post-scaled asset check and create pool tests

* add sanity tests for new swap guardrails

* move max scaled asset amt to constant

* add join-pool-internal tests for new functionality
  • Loading branch information
AlpinYukseloglu authored Oct 24, 2022
1 parent 8590b80 commit ced84c1
Show file tree
Hide file tree
Showing 8 changed files with 390 additions and 22 deletions.
28 changes: 20 additions & 8 deletions x/gamm/pool-models/stableswap/amm.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,11 +377,14 @@ func (p *Pool) calcSingleAssetJoinShares(tokenIn sdk.Coin, swapFee sdk.Dec) (sdk
// We can mutate pa here
// TODO: some day switch this to a COW wrapped pa, for better perf
func (p *Pool) joinPoolSharesInternal(ctx sdk.Context, tokensIn sdk.Coins, swapFee sdk.Dec) (numShares sdk.Int, newLiquidity sdk.Coins, err error) {
if len(tokensIn) == 1 {
if !tokensIn.DenomsSubsetOf(p.GetTotalPoolLiquidity(ctx)) {
return sdk.ZeroInt(), sdk.NewCoins(), errors.New("attempted joining pool with assets that do not exist in pool")
}
if len(tokensIn) == 1 && tokensIn[0].Amount.GT(sdk.OneInt()) {
numShares, err = p.calcSingleAssetJoinShares(tokensIn[0], swapFee)
newLiquidity = tokensIn
return numShares, newLiquidity, err
} else if len(tokensIn) != p.NumAssets() || !tokensIn.DenomsSubsetOf(p.GetTotalPoolLiquidity(ctx)) {
} else if len(tokensIn) != p.NumAssets() {
return sdk.ZeroInt(), sdk.NewCoins(), errors.New(
"stableswap pool only supports LP'ing with one asset, or all assets in pool")
}
Expand All @@ -393,15 +396,24 @@ func (p *Pool) joinPoolSharesInternal(ctx sdk.Context, tokensIn sdk.Coins, swapF
}
p.updatePoolForJoin(tokensIn.Sub(remCoins), numShares)

tokensJoined := tokensIn
for _, coin := range remCoins {
// TODO: Perhaps add a method to skip if this is too small.
newShare, err := p.calcSingleAssetJoinShares(coin, swapFee)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
if coin.Amount.GT(sdk.OneInt()) {
newShare, err := p.calcSingleAssetJoinShares(coin, swapFee)
if err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}
p.updatePoolForJoin(sdk.NewCoins(coin), newShare)
numShares = numShares.Add(newShare)
} else {
tokensJoined = tokensJoined.Sub(sdk.NewCoins(coin))
}
p.updatePoolForJoin(sdk.NewCoins(coin), newShare)
numShares = numShares.Add(newShare)
}

return numShares, tokensIn, nil
if err = validatePoolAssets(p.PoolLiquidity, p.ScalingFactor); err != nil {
return sdk.ZeroInt(), sdk.NewCoins(), err
}

return numShares, tokensJoined, nil
}
146 changes: 140 additions & 6 deletions x/gamm/pool-models/stableswap/amm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,16 +224,38 @@ var (
yIn: osmomath.NewBigDec(1),
expectPanic: false,
},
/* TODO: increase BigDec precision (36 -> 72) to be able to accommodate this
"even 4-asset large pool, small input": {
"even 4-asset large pool (100M each), small input": {
xReserve: osmomath.NewBigDec(100000000),
yReserve: osmomath.NewBigDec(100000000),
// represents a 4-asset pool with 100M in each reserve
remReserves: []osmomath.BigDec{osmomath.NewBigDec(100000000), osmomath.NewBigDec(100000000)},
yIn: osmomath.NewBigDec(100),
yIn: osmomath.NewBigDec(100),
expectPanic: false,
},
"even 4-asset pool (10B each post-scaled), small input": {
xReserve: osmomath.NewBigDec(10000000000),
yReserve: osmomath.NewBigDec(10000000000),
// represents a 4-asset pool with 10B in each reserve
remReserves: []osmomath.BigDec{osmomath.NewBigDec(10000000000), osmomath.NewBigDec(10000000000)},
yIn: osmomath.NewBigDec(100000000),
expectPanic: false,
},
"even 10-asset pool (10B each post-scaled), small input": {
xReserve: osmomath.NewBigDec(10_000_000_000),
yReserve: osmomath.NewBigDec(10_000_000_000),
// represents a 10-asset pool with 10B in each reserve
remReserves: []osmomath.BigDec{osmomath.NewBigDec(10_000_000_000), osmomath.NewBigDec(10_000_000_000), osmomath.NewBigDec(10_000_000_000), osmomath.NewBigDec(10_000_000_000), osmomath.NewBigDec(10_000_000_000), osmomath.NewBigDec(10_000_000_000), osmomath.NewBigDec(10_000_000_000), osmomath.NewBigDec(10_000_000_000)},
yIn: osmomath.NewBigDec(100),
expectPanic: false,
},
"even 10-asset pool (100B each post-scaled), large input": {
xReserve: osmomath.NewBigDec(100_000_000_000),
yReserve: osmomath.NewBigDec(100_000_000_000),
// represents a 10-asset pool with 100B in each reserve
remReserves: []osmomath.BigDec{osmomath.NewBigDec(100_000_000_000), osmomath.NewBigDec(100_000_000_000), osmomath.NewBigDec(100_000_000_000), osmomath.NewBigDec(100_000_000_000), osmomath.NewBigDec(100_000_000_000), osmomath.NewBigDec(100_000_000_000), osmomath.NewBigDec(100_000_000_000), osmomath.NewBigDec(100_000_000_000)},
yIn: osmomath.NewBigDec(10_000_000_000),
expectPanic: false,
},
*/

// uneven pools
"uneven 3-asset pool, even swap assets as pool minority": {
Expand Down Expand Up @@ -790,8 +812,6 @@ func TestCalcSingleAssetJoinShares(t *testing.T) {
swapFee: sdk.MustNewDecFromStr("0.03"),
expectedOut: sdk.NewInt(100 - 3),
},

// TODO: increase BigDec precision further to be able to accommodate 5-asset pool tests
}

for name, tc := range tests {
Expand All @@ -815,3 +835,117 @@ func TestCalcSingleAssetJoinShares(t *testing.T) {
})
}
}

func TestJoinPoolSharesInternal(t *testing.T) {
tenPercentOfTwoPoolRaw := int64(1000000000 / 10)
tenPercentOfTwoPoolCoins := sdk.NewCoins(sdk.NewCoin("foo", sdk.NewInt(int64(1000000000/10))), sdk.NewCoin("bar", sdk.NewInt(int64(1000000000/10))))
twoAssetPlusTenPercent := twoEvenStablePoolAssets.Add(tenPercentOfTwoPoolCoins...)
type testcase struct {
tokensIn sdk.Coins
poolAssets sdk.Coins
scalingFactors []uint64
swapFee sdk.Dec
expNumShare sdk.Int
expTokensJoined sdk.Coins
expPoolAssets sdk.Coins
expectPass bool
}

tests := map[string]testcase{
"even two asset pool, same tokenIn ratio": {
tokensIn: tenPercentOfTwoPoolCoins,
poolAssets: twoEvenStablePoolAssets,
scalingFactors: defaultTwoAssetScalingFactors,
swapFee: sdk.ZeroDec(),
expNumShare: sdk.NewIntFromUint64(10000000000000000000),
expTokensJoined: tenPercentOfTwoPoolCoins,
expPoolAssets: twoAssetPlusTenPercent,
expectPass: true,
},
"even two asset pool, different tokenIn ratio with pool": {
tokensIn: sdk.NewCoins(sdk.NewCoin("foo", sdk.NewInt(tenPercentOfTwoPoolRaw)), sdk.NewCoin("bar", sdk.NewInt(10+tenPercentOfTwoPoolRaw))),
poolAssets: twoEvenStablePoolAssets,
scalingFactors: defaultTwoAssetScalingFactors,
swapFee: sdk.ZeroDec(),
expNumShare: sdk.NewIntFromUint64(10000000500000000000),
expTokensJoined: sdk.NewCoins(sdk.NewCoin("foo", sdk.NewInt(tenPercentOfTwoPoolRaw)), sdk.NewCoin("bar", sdk.NewInt(10+tenPercentOfTwoPoolRaw))),
expPoolAssets: twoAssetPlusTenPercent.Add(sdk.NewCoin("bar", sdk.NewInt(10))),
expectPass: true,
},
"all-asset pool join attempt exceeds max scaled asset amount": {
tokensIn: sdk.NewCoins(
sdk.NewInt64Coin("foo", 1),
sdk.NewInt64Coin("bar", 1),
),
poolAssets: sdk.NewCoins(
sdk.NewInt64Coin("foo", 10_000_000_000),
sdk.NewInt64Coin("bar", 10_000_000_000),
),
scalingFactors: defaultTwoAssetScalingFactors,
swapFee: sdk.ZeroDec(),
expNumShare: sdk.ZeroInt(),
expTokensJoined: sdk.Coins{},
expPoolAssets: sdk.NewCoins(
sdk.NewInt64Coin("foo", 10_000_000_000),
sdk.NewInt64Coin("bar", 10_000_000_000),
),
expectPass: false,
},
"single-asset pool join exceeds hits max scaled asset amount": {
tokensIn: sdk.NewCoins(
sdk.NewInt64Coin("foo", 1),
),
poolAssets: sdk.NewCoins(
sdk.NewInt64Coin("foo", 10_000_000_000),
sdk.NewInt64Coin("bar", 10_000_000_000),
),
scalingFactors: defaultTwoAssetScalingFactors,
swapFee: sdk.ZeroDec(),
expNumShare: sdk.ZeroInt(),
expTokensJoined: sdk.Coins{},
expPoolAssets: sdk.NewCoins(
sdk.NewInt64Coin("foo", 10_000_000_000),
sdk.NewInt64Coin("bar", 10_000_000_000),
),
expectPass: false,
},
"all-asset pool join attempt exactly hits max scaled asset amount": {
tokensIn: sdk.NewCoins(
sdk.NewInt64Coin("foo", 1),
sdk.NewInt64Coin("bar", 1),
),
poolAssets: sdk.NewCoins(
sdk.NewInt64Coin("foo", 9_999_999_999),
sdk.NewInt64Coin("bar", 9_999_999_999),
),
scalingFactors: defaultTwoAssetScalingFactors,
swapFee: sdk.ZeroDec(),
expNumShare: sdk.NewInt(10000000000),
expTokensJoined: sdk.NewCoins(
sdk.NewInt64Coin("foo", 1),
sdk.NewInt64Coin("bar", 1),
),
expPoolAssets: sdk.NewCoins(
sdk.NewInt64Coin("foo", 10_000_000_000),
sdk.NewInt64Coin("bar", 10_000_000_000),
),
expectPass: true,
},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
ctx := sdk.Context{}
p := poolStructFromAssets(tc.poolAssets, tc.scalingFactors)

shares, joinedLiquidity, err := p.joinPoolSharesInternal(ctx, tc.tokensIn, tc.swapFee)

if tc.expectPass {
require.Equal(t, tc.expNumShare, shares)
require.Equal(t, tc.expTokensJoined, joinedLiquidity)
require.Equal(t, tc.expPoolAssets, p.PoolLiquidity)
}
osmoassert.ConditionalError(t, !tc.expectPass, err)
})
}
}
13 changes: 6 additions & 7 deletions x/gamm/pool-models/stableswap/msgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,6 @@ func (msg MsgCreateStableswapPool) ValidateBasic() error {
return err
}

// validation for pool initial liquidity
if len(msg.InitialPoolLiquidity) < 2 {
return types.ErrTooFewPoolAssets
} else if len(msg.InitialPoolLiquidity) > 8 {
return types.ErrTooManyPoolAssets
}

// validation for scaling factors
// The message's scaling factors must be empty or a valid set of scaling factors
if len(msg.ScalingFactors) != 0 {
Expand All @@ -61,6 +54,12 @@ func (msg MsgCreateStableswapPool) ValidateBasic() error {
}
}

// validation for pool initial liquidity
// The message's pool liquidity must have between 2 and 8 assets with at most 10B post-scaled units in each
if err = validatePoolAssets(msg.InitialPoolLiquidity, msg.ScalingFactors); err != nil {
return err
}

// validation for scaling factor owner
if err = validateScalingFactorController(msg.ScalingFactorController); err != nil {
return err
Expand Down
54 changes: 54 additions & 0 deletions x/gamm/pool-models/stableswap/msgs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,60 @@ func TestMsgCreateStableswapPool(t *testing.T) {
}),
expectPass: true,
},
{
name: "max asset amounts",
msg: createMsg(func(msg stableswap.MsgCreateStableswapPool) stableswap.MsgCreateStableswapPool {
msg.InitialPoolLiquidity = sdk.Coins{
sdk.NewCoin("osmo", sdk.NewInt(10_000_000_000)),
sdk.NewCoin("atom", sdk.NewInt(10_000_000_000)),
sdk.NewCoin("usdt", sdk.NewInt(10_000_000_000)),
sdk.NewCoin("usdc", sdk.NewInt(10_000_000_000)),
sdk.NewCoin("juno", sdk.NewInt(10_000_000_000)),
sdk.NewCoin("akt", sdk.NewInt(10_000_000_000)),
sdk.NewCoin("regen", sdk.NewInt(10_000_000_000)),
sdk.NewCoin("band", sdk.NewInt(10_000_000_000)),
}
msg.ScalingFactors = []uint64{1, 1, 1, 1, 1, 1, 1, 1}
return msg
}),
expectPass: true,
},
{
name: "greater than max post-scaled amount with regular scaling factors",
msg: createMsg(func(msg stableswap.MsgCreateStableswapPool) stableswap.MsgCreateStableswapPool {
msg.InitialPoolLiquidity = sdk.Coins{
sdk.NewCoin("osmo", sdk.NewInt(1+10_000_000_000)),
sdk.NewCoin("atom", sdk.NewInt(10_000_000_000)),
sdk.NewCoin("usdt", sdk.NewInt(10_000_000_000)),
sdk.NewCoin("usdc", sdk.NewInt(10_000_000_000)),
sdk.NewCoin("juno", sdk.NewInt(10_000_000_000)),
sdk.NewCoin("akt", sdk.NewInt(10_000_000_000)),
sdk.NewCoin("regen", sdk.NewInt(10_000_000_000)),
sdk.NewCoin("band", sdk.NewInt(10_000_000_000)),
}
msg.ScalingFactors = []uint64{1, 1, 1, 1, 1, 1, 1, 1}
return msg
}),
expectPass: false,
},
{
name: "100B token 8-asset pool using large scaling factors",
msg: createMsg(func(msg stableswap.MsgCreateStableswapPool) stableswap.MsgCreateStableswapPool {
msg.InitialPoolLiquidity = sdk.Coins{
sdk.NewCoin("osmo", sdk.NewInt(100_000_000_000_000_000)),
sdk.NewCoin("atom", sdk.NewInt(100_000_000_000_000_000)),
sdk.NewCoin("usdt", sdk.NewInt(100_000_000_000_000_000)),
sdk.NewCoin("usdc", sdk.NewInt(100_000_000_000_000_000)),
sdk.NewCoin("juno", sdk.NewInt(100_000_000_000_000_000)),
sdk.NewCoin("akt", sdk.NewInt(100_000_000_000_000_000)),
sdk.NewCoin("regen", sdk.NewInt(100_000_000_000_000_000)),
sdk.NewCoin("band", sdk.NewInt(100_000_000_000_000_000)),
}
msg.ScalingFactors = []uint64{10000000, 10000000, 10000000, 10000000, 10000000, 10000000, 10000000, 10000000}
return msg
}),
expectPass: true,
},
}

for _, test := range tests {
Expand Down
32 changes: 32 additions & 0 deletions x/gamm/pool-models/stableswap/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func NewStableswapPool(poolId uint64,
return Pool{}, err
}

if err := validatePoolAssets(initialLiquidity, scalingFactors); err != nil {
return Pool{}, err
}

pool := Pool{
Address: types.NewPoolAddress(poolId).String(),
Id: poolId,
Expand Down Expand Up @@ -229,6 +233,10 @@ func (p Pool) CalcOutAmtGivenIn(ctx sdk.Context, tokenIn sdk.Coins, tokenOutDeno
}

func (p *Pool) SwapOutAmtGivenIn(ctx sdk.Context, tokenIn sdk.Coins, tokenOutDenom string, swapFee sdk.Dec) (tokenOut sdk.Coin, err error) {
if err = validatePoolAssets(p.PoolLiquidity.Add(tokenIn...), p.ScalingFactor); err != nil {
return sdk.Coin{}, err
}

tokenOut, err = p.CalcOutAmtGivenIn(ctx, tokenIn, tokenOutDenom, swapFee)
if err != nil {
return sdk.Coin{}, err
Expand Down Expand Up @@ -265,6 +273,10 @@ func (p *Pool) SwapInAmtGivenOut(ctx sdk.Context, tokenOut sdk.Coins, tokenInDen
return sdk.Coin{}, err
}

if err = validatePoolAssets(p.PoolLiquidity.Add(tokenIn), p.ScalingFactor); err != nil {
return sdk.Coin{}, err
}

p.updatePoolLiquidityForSwap(sdk.NewCoins(tokenIn), tokenOut)

return tokenIn, nil
Expand Down Expand Up @@ -328,6 +340,10 @@ func (p *Pool) SetStableSwapScalingFactors(ctx sdk.Context, scalingFactors []uin
return err
}

if err := validatePoolAssets(p.PoolLiquidity, scalingFactors); err != nil {
return err
}

p.ScalingFactor = scalingFactors
return nil
}
Expand All @@ -353,3 +369,19 @@ func validateScalingFactors(scalingFactors []uint64, numAssets int) error {

return nil
}

func validatePoolAssets(initialAssets sdk.Coins, scalingFactors []uint64) error {
if len(initialAssets) < types.MinPoolAssets {
return types.ErrTooFewPoolAssets
} else if len(initialAssets) > types.MaxPoolAssets {
return types.ErrTooManyPoolAssets
}

for i, asset := range initialAssets {
if asset.Amount.Quo(sdk.NewInt(int64(scalingFactors[i]))).GT(sdk.NewInt(types.StableswapMaxScaledAmtPerAsset)) {
return types.ErrHitMaxScaledAssets
}
}

return nil
}
Loading

0 comments on commit ced84c1

Please sign in to comment.