diff --git a/x/gamm/keeper/math.go b/x/gamm/keeper/math.go index 2132eca29ac..49fc880a16d 100644 --- a/x/gamm/keeper/math.go +++ b/x/gamm/keeper/math.go @@ -54,7 +54,35 @@ func calcSpotPriceWithSwapFee( return spotPrice.Mul(scale) } -// aO +// solveConstantFunctionInvariant solves the constant function of an AMM +// that determines the relationship between the differences of two sides +// of assets inside the pool. +// For fixed balanceXBefore, balanceXAfter, weightX, balanceY, weightY, +// we could deduce the balanceYDelta, calculated by: +// balanceYDelta = balanceY * (1 - (balanceXBefore/balanceXAfter)^(weightX/weightY)) +// balanceYDelta is positive when the balance liquidity decreases. +// balanceYDelta is negative when the balance liquidity increases. +func solveConstantFunctionInvariant( + tokenBalanceFixedBefore, + tokenBalanceFixedAfter, + tokenWeightFixed, + tokenBalanceUnknownBefore, + tokenWeightUnknown sdk.Dec, +) sdk.Dec { + // weightRatio = (weightX/weightY) + weightRatio := tokenWeightFixed.Quo(tokenWeightUnknown) + + // y = balanceXBefore/balanceYAfter + y := tokenBalanceFixedBefore.Quo(tokenBalanceFixedAfter) + + // amountY = balanceY * (1 - (y ^ weightRatio)) + foo := osmomath.Pow(y, weightRatio) + multiplier := sdk.OneDec().Sub(foo) + return tokenBalanceUnknownBefore.Mul(multiplier) +} + +// calcOutGivenIn calculates token to be swapped out given +// the provided amount, fee deducted, using solveConstantFunctionInvariant func calcOutGivenIn( tokenBalanceIn, tokenWeightIn, @@ -63,16 +91,15 @@ func calcOutGivenIn( tokenAmountIn, swapFee sdk.Dec, ) sdk.Dec { - weightRatio := tokenWeightIn.Quo(tokenWeightOut) - adjustedIn := sdk.OneDec().Sub(swapFee) - adjustedIn = tokenAmountIn.Mul(adjustedIn) - y := tokenBalanceIn.Quo(tokenBalanceIn.Add(adjustedIn)) - foo := osmomath.Pow(y, weightRatio) - bar := sdk.OneDec().Sub(foo) - return tokenBalanceOut.Mul(bar) + // deduct swapfee on the in asset + tokenAmountInAfterFee := tokenAmountIn.Mul(sdk.OneDec().Sub(swapFee)) + // delta balanceOut is positive(tokens inside the pool decreases) + tokenAmountOut := solveConstantFunctionInvariant(tokenBalanceIn, tokenBalanceIn.Add(tokenAmountInAfterFee), tokenWeightIn, tokenBalanceOut, tokenWeightOut) + return tokenAmountOut } -// aI +// calcInGivenOut calculates token to be provided, fee added, +// given the swapped out amount, using solveConstantFunctionInvariant func calcInGivenOut( tokenBalanceIn, tokenWeightIn, @@ -81,120 +108,96 @@ func calcInGivenOut( tokenAmountOut, swapFee sdk.Dec, ) sdk.Dec { - weightRatio := tokenWeightOut.Quo(tokenWeightIn) - diff := tokenBalanceOut.Sub(tokenAmountOut) - y := tokenBalanceOut.Quo(diff) - foo := osmomath.Pow(y, weightRatio) - foo = foo.Sub(one) - tokenAmountIn := sdk.OneDec().Sub(swapFee) - return (tokenBalanceIn.Mul(foo)).Quo(tokenAmountIn) + // delta balanceIn is negative(amount of tokens inside the pool increases) + tokenAmountIn := solveConstantFunctionInvariant(tokenBalanceOut, tokenBalanceOut.Sub(tokenAmountOut), tokenWeightOut, tokenBalanceIn, tokenWeightIn).Neg() + // We deduct a swap fee on the input asset. The swap happens by following the invariant curve on the input * (1 - swap fee) + // and then the swap fee is added to the pool. + // Thus in order to give X amount out, we solve the invariant for the invariant input. However invariant input = (1 - swapfee) * trade input. + // Therefore we divide by (1 - swapfee) here + tokenAmountInBeforeFee := tokenAmountIn.Quo(sdk.OneDec().Sub(swapFee)) + return tokenAmountInBeforeFee } -// pAo -func calcPoolOutGivenSingleIn( - tokenBalanceIn, - tokenWeightIn, - poolSupply, - totalWeight, - tokenAmountIn, +func feeRatio( + normalizedWeight, swapFee sdk.Dec, ) sdk.Dec { - normalizedWeight := tokenWeightIn.Quo(totalWeight) - zaz := (sdk.OneDec().Sub(normalizedWeight)).Mul(swapFee) - tokenAmountInAfterFee := tokenAmountIn.Mul(sdk.OneDec().Sub(zaz)) - - newTokenBalanceIn := tokenBalanceIn.Add(tokenAmountInAfterFee) - tokenInRatio := newTokenBalanceIn.Quo(tokenBalanceIn) - - // uint newPoolSupply = (ratioTi ^ weightTi) * poolSupply; - poolRatio := osmomath.Pow(tokenInRatio, normalizedWeight) - newPoolSupply := poolRatio.Mul(poolSupply) - return newPoolSupply.Sub(poolSupply) + zar := (sdk.OneDec().Sub(normalizedWeight)).Mul(swapFee) + return sdk.OneDec().Sub(zar) } -//tAi +// calcSingleInGivenPoolOut calculates token to be provided, fee added, +// given the swapped out shares amount, using solveConstantFunctionInvariant func calcSingleInGivenPoolOut( tokenBalanceIn, - tokenWeightIn, + normalizedTokenWeightIn, poolSupply, - totalWeight, poolAmountOut, swapFee sdk.Dec, ) sdk.Dec { - normalizedWeight := tokenWeightIn.Quo(totalWeight) - newPoolSupply := poolSupply.Add(poolAmountOut) - poolRatio := newPoolSupply.Quo(poolSupply) - - //uint newBalTi = poolRatio^(1/weightTi) * balTi; - boo := sdk.OneDec().Quo(normalizedWeight) - tokenInRatio := osmomath.Pow(poolRatio, boo) - newTokenBalanceIn := tokenInRatio.Mul(tokenBalanceIn) - tokenAmountInAfterFee := newTokenBalanceIn.Sub(tokenBalanceIn) - // Do reverse order of fees charged in joinswap_ExternAmountIn, this way - // ``` pAo == joinswap_ExternAmountIn(Ti, joinswap_PoolAmountOut(pAo, Ti)) ``` - //uint tAi = tAiAfterFee / (1 - (1-weightTi) * swapFee) ; - zar := (sdk.OneDec().Sub(normalizedWeight)).Mul(swapFee) - return tokenAmountInAfterFee.Quo(sdk.OneDec().Sub(zar)) + // delta balanceIn is negative(tokens inside the pool increases) + // pool weight is always 1 + tokenAmountIn := solveConstantFunctionInvariant(poolSupply.Add(poolAmountOut), poolSupply, sdk.OneDec(), tokenBalanceIn, normalizedTokenWeightIn).Neg() + // deduct swapfee on the in asset + tokenAmountInBeforeFee := tokenAmountIn.Quo(feeRatio(normalizedTokenWeightIn, swapFee)) + return tokenAmountInBeforeFee +} + +// pAo +func calcPoolOutGivenSingleIn( + tokenBalanceIn, + normalizedTokenWeightIn, + poolSupply, + tokenAmountIn, + swapFee sdk.Dec, +) sdk.Dec { + // deduct swapfee on the in asset + tokenAmountInAfterFee := tokenAmountIn.Mul(feeRatio(normalizedTokenWeightIn, swapFee)) + // delta poolSupply is negative(total pool shares increases) + // pool weight is always 1 + poolAmountOut := solveConstantFunctionInvariant(tokenBalanceIn.Add(tokenAmountInAfterFee), tokenBalanceIn, normalizedTokenWeightIn, poolSupply, sdk.OneDec()).Neg() + return poolAmountOut } // tAo func calcSingleOutGivenPoolIn( tokenBalanceOut, - tokenWeightOut, + normalizedTokenWeightOut, poolSupply, - totalWeight, poolAmountIn, swapFee sdk.Dec, exitFee sdk.Dec, ) sdk.Dec { - normalizedWeight := tokenWeightOut.Quo(totalWeight) // charge exit fee on the pool token side // pAiAfterExitFee = pAi*(1-exitFee) poolAmountInAfterExitFee := poolAmountIn.Mul(sdk.OneDec().Sub(exitFee)) - newPoolSupply := poolSupply.Sub(poolAmountInAfterExitFee) - poolRatio := newPoolSupply.Quo(poolSupply) - - // newBalTo = poolRatio^(1/weightTo) * balTo; - - tokenOutRatio := osmomath.Pow(poolRatio, sdk.OneDec().Quo(normalizedWeight)) - newTokenBalanceOut := tokenOutRatio.Mul(tokenBalanceOut) - tokenAmountOutBeforeSwapFee := tokenBalanceOut.Sub(newTokenBalanceOut) - - // charge swap fee on the output token side - //uint tAo = tAoBeforeSwapFee * (1 - (1-weightTo) * swapFee) - zaz := (sdk.OneDec().Sub(normalizedWeight)).Mul(swapFee) - tokenAmountOut := tokenAmountOutBeforeSwapFee.Mul(sdk.OneDec().Sub(zaz)) - return tokenAmountOut + // delta balanceOut is positive(tokens inside the pool decreases) + // pool weight is always 1 + tokenAmountOut := solveConstantFunctionInvariant(poolSupply.Sub(poolAmountInAfterExitFee), poolSupply, sdk.OneDec(), tokenBalanceOut, normalizedTokenWeightOut) + // deduct + tokenAmountOutAfterFee := tokenAmountOut.Mul(feeRatio(normalizedTokenWeightOut, swapFee)) + return tokenAmountOutAfterFee } // pAi func calcPoolInGivenSingleOut( tokenBalanceOut, - tokenWeightOut, + normalizedTokenWeightOut, poolSupply, - totalWeight, tokenAmountOut, swapFee sdk.Dec, exitFee sdk.Dec, ) sdk.Dec { - // charge swap fee on the output token side - normalizedWeight := tokenWeightOut.Quo(totalWeight) - //uint tAoBeforeSwapFee = tAo / (1 - (1-weightTo) * swapFee) ; - zoo := sdk.OneDec().Sub(normalizedWeight) - zar := zoo.Mul(swapFee) - tokenAmountOutBeforeSwapFee := tokenAmountOut.Quo(sdk.OneDec().Sub(zar)) - - newTokenBalanceOut := tokenBalanceOut.Sub(tokenAmountOutBeforeSwapFee) - tokenOutRatio := newTokenBalanceOut.Quo(tokenBalanceOut) + tokenAmountOutBeforeFee := tokenAmountOut.Quo(feeRatio(normalizedTokenWeightOut, swapFee)) - //uint newPoolSupply = (ratioTo ^ weightTo) * poolSupply; - poolRatio := osmomath.Pow(tokenOutRatio, normalizedWeight) - newPoolSupply := poolRatio.Mul(poolSupply) - poolAmountInAfterExitFee := poolSupply.Sub(newPoolSupply) + // delta poolSupply is positive(total pool shares decreases) + // pool weight is always 1 + poolAmountIn := solveConstantFunctionInvariant(tokenBalanceOut.Sub(tokenAmountOutBeforeFee), tokenBalanceOut, normalizedTokenWeightOut, poolSupply, sdk.OneDec()) // charge exit fee on the pool token side // pAi = pAiAfterExitFee/(1-exitFee) - return poolAmountInAfterExitFee.Quo(sdk.OneDec().Sub(exitFee)) + poolAmountInBeforeFee := poolAmountIn.Quo(sdk.OneDec().Sub(exitFee)) + return poolAmountInBeforeFee } diff --git a/x/gamm/keeper/math_test.go b/x/gamm/keeper/math_test.go index 0c400c05351..133569e4903 100644 --- a/x/gamm/keeper/math_test.go +++ b/x/gamm/keeper/math_test.go @@ -1,6 +1,7 @@ package keeper import ( + "math/rand" "testing" sdk "github.com/cosmos/cosmos-sdk/types" @@ -9,17 +10,9 @@ import ( ) func TestCalcSpotPrice(t *testing.T) { - // TODO: Change test to be table driven - tokenBalanceIn, err := sdk.NewDecFromStr("100") - require.NoError(t, err) - tokenWeightIn, err := sdk.NewDecFromStr("0.1") - require.NoError(t, err) - tokenBalanceOut, err := sdk.NewDecFromStr("200") - require.NoError(t, err) - tokenWeightOut, err := sdk.NewDecFromStr("0.3") - require.NoError(t, err) + tc := tc(t, "100", "0.1", "200", "0.3", "", "0", "0", "0") - actual_spot_price := calcSpotPrice(tokenBalanceIn, tokenWeightIn, tokenBalanceOut, tokenWeightOut) + actual_spot_price := calcSpotPrice(tc.tokenBalanceIn, tc.tokenWeightIn, tc.tokenBalanceOut, tc.tokenWeightOut) // s = (100/.1) / (200 / .3) = (1000) / (2000 / 3) = 1.5 expected_spot_price, err := sdk.NewDecFromStr("1.5") require.NoError(t, err) @@ -35,18 +28,9 @@ func TestCalcSpotPrice(t *testing.T) { // TODO: Create test vectors with balancer contract func TestCalcSpotPriceWithSwapFee(t *testing.T) { - tokenBalanceIn, err := sdk.NewDecFromStr("100") - require.NoError(t, err) - tokenWeightIn, err := sdk.NewDecFromStr("0.1") - require.NoError(t, err) - tokenBalanceOut, err := sdk.NewDecFromStr("200") - require.NoError(t, err) - tokenWeightOut, err := sdk.NewDecFromStr("0.3") - require.NoError(t, err) - swapFee, err := sdk.NewDecFromStr("0.01") - require.NoError(t, err) + tc := tc(t, "100", "0.1", "200", "0.3", "", "0", "0.01", "0") - s := calcSpotPriceWithSwapFee(tokenBalanceIn, tokenWeightIn, tokenBalanceOut, tokenWeightOut, swapFee) + s := calcSpotPriceWithSwapFee(tc.tokenBalanceIn, tc.tokenWeightIn, tc.tokenBalanceOut, tc.tokenWeightOut, tc.swapFee) expectedDec, err := sdk.NewDecFromStr("1.51515151") require.NoError(t, err) @@ -60,21 +44,12 @@ func TestCalcSpotPriceWithSwapFee(t *testing.T) { } func TestCalcOutGivenIn(t *testing.T) { + tc := tc(t, "100", "0.1", "200", "0.3", "", "0", "0.01", "0") - tokenBalanceIn, err := sdk.NewDecFromStr("100") - require.NoError(t, err) - tokenWeightIn, err := sdk.NewDecFromStr("0.1") - require.NoError(t, err) - tokenBalanceOut, err := sdk.NewDecFromStr("200") - require.NoError(t, err) - tokenWeightOut, err := sdk.NewDecFromStr("0.3") - require.NoError(t, err) tokenAmountIn, err := sdk.NewDecFromStr("40") require.NoError(t, err) - swapFee, err := sdk.NewDecFromStr("0.01") - require.NoError(t, err) - s := calcOutGivenIn(tokenBalanceIn, tokenWeightIn, tokenBalanceOut, tokenWeightOut, tokenAmountIn, swapFee) + s := tc.calcOutGivenIn(tokenAmountIn) expectedDec, err := sdk.NewDecFromStr("21.0487006") require.NoError(t, err) @@ -88,21 +63,11 @@ func TestCalcOutGivenIn(t *testing.T) { } func TestCalcInGivenOut(t *testing.T) { - - tokenBalanceIn, err := sdk.NewDecFromStr("100") - require.NoError(t, err) - tokenWeightIn, err := sdk.NewDecFromStr("0.1") - require.NoError(t, err) - tokenBalanceOut, err := sdk.NewDecFromStr("200") - require.NoError(t, err) - tokenWeightOut, err := sdk.NewDecFromStr("0.3") - require.NoError(t, err) + tc := tc(t, "100", "0.1", "200", "0.3", "", "0", "0.01", "0") tokenAmountOut, err := sdk.NewDecFromStr("70") require.NoError(t, err) - swapFee, err := sdk.NewDecFromStr("0.01") - require.NoError(t, err) - s := calcInGivenOut(tokenBalanceIn, tokenWeightIn, tokenBalanceOut, tokenWeightOut, tokenAmountOut, swapFee) + s := tc.calcInGivenOut(tokenAmountOut) expectedDec, err := sdk.NewDecFromStr("266.8009177") require.NoError(t, err) @@ -115,21 +80,12 @@ func TestCalcInGivenOut(t *testing.T) { } func TestCalcPoolOutGivenSingleIn(t *testing.T) { + tc := tc(t, "100", "0.2", "200", "0.8", "1", "300", "0.15", "0") - tokenBalanceIn, err := sdk.NewDecFromStr("100") - require.NoError(t, err) - tokenWeightIn, err := sdk.NewDecFromStr("0.2") - require.NoError(t, err) - poolSupply, err := sdk.NewDecFromStr("300") - require.NoError(t, err) - totalWeight, err := sdk.NewDecFromStr("1") - require.NoError(t, err) tokenAmountIn, err := sdk.NewDecFromStr("40") require.NoError(t, err) - swapFee, err := sdk.NewDecFromStr("0.15") - require.NoError(t, err) - s := calcPoolOutGivenSingleIn(tokenBalanceIn, tokenWeightIn, poolSupply, totalWeight, tokenAmountIn, swapFee) + s := tc.calcPoolOutGivenSingleIn(tokenAmountIn) expectedDec, err := sdk.NewDecFromStr("18.6519592") require.NoError(t, err) @@ -157,7 +113,8 @@ func TestCalcSingleInGivenPoolOut(t *testing.T) { swapFee, err := sdk.NewDecFromStr("0.15") require.NoError(t, err) - s := calcSingleInGivenPoolOut(tokenBalanceIn, tokenWeightIn, poolSupply, totalWeight, poolAmountOut, swapFee) + normalizedWeight := tokenWeightIn.Quo(totalWeight) + s := calcSingleInGivenPoolOut(tokenBalanceIn, normalizedWeight, poolSupply, poolAmountOut, swapFee) expectedDec, err := sdk.NewDecFromStr(".") require.NoError(t, err) @@ -171,21 +128,11 @@ func TestCalcSingleInGivenPoolOut(t *testing.T) { */ func TestCalcSingleOutGivenPoolIn(t *testing.T) { - - tokenBalanceOut, err := sdk.NewDecFromStr("200") - require.NoError(t, err) - tokenWeightOut, err := sdk.NewDecFromStr("0.8") - require.NoError(t, err) - poolSupply, err := sdk.NewDecFromStr("300") - require.NoError(t, err) - totalWeight, err := sdk.NewDecFromStr("1") - require.NoError(t, err) + tc := tc(t, "100", "0.2", "200", "0.8", "1", "300", "0.15", "0") poolAmountIn, err := sdk.NewDecFromStr("40") require.NoError(t, err) - swapFee, err := sdk.NewDecFromStr("0.15") - require.NoError(t, err) - s := calcSingleOutGivenPoolIn(tokenBalanceOut, tokenWeightOut, poolSupply, totalWeight, poolAmountIn, swapFee, sdk.ZeroDec()) + s := tc.calcSingleOutGivenPoolIn(poolAmountIn) expectedDec, err := sdk.NewDecFromStr("31.77534976") require.NoError(t, err) @@ -198,21 +145,12 @@ func TestCalcSingleOutGivenPoolIn(t *testing.T) { } func TestCalcPoolInGivenSingleOut(t *testing.T) { + tc := tc(t, "100", "0.2", "200", "0.8", "1", "300", "0.15", "0") - tokenBalanceOut, err := sdk.NewDecFromStr("200") - require.NoError(t, err) - tokenWeightOut, err := sdk.NewDecFromStr("0.8") - require.NoError(t, err) - poolSupply, err := sdk.NewDecFromStr("300") - require.NoError(t, err) - totalWeight, err := sdk.NewDecFromStr("1") - require.NoError(t, err) tokenAmountOut, err := sdk.NewDecFromStr("70") require.NoError(t, err) - swapFee, err := sdk.NewDecFromStr("0.15") - require.NoError(t, err) - s := calcPoolInGivenSingleOut(tokenBalanceOut, tokenWeightOut, poolSupply, totalWeight, tokenAmountOut, swapFee, sdk.ZeroDec()) + s := tc.calcPoolInGivenSingleOut(tokenAmountOut) expectedDec, err := sdk.NewDecFromStr("90.29092777") require.NoError(t, err) @@ -223,3 +161,120 @@ func TestCalcPoolInGivenSingleOut(t *testing.T) { "expected value & actual value's difference should less than precision*10000", ) } + +type testCase struct { + tokenBalanceIn, tokenWeightIn sdk.Dec + tokenBalanceOut, tokenWeightOut sdk.Dec + totalWeight sdk.Dec + poolSupply sdk.Dec + swapFee, exitFee sdk.Dec +} + +func (tc testCase) reverse() testCase { + return testCase{ + tc.tokenBalanceOut, tc.tokenWeightOut, + tc.tokenBalanceIn, tc.tokenWeightIn, + tc.totalWeight, + tc.poolSupply, + tc.swapFee, tc.exitFee, + } +} + +func tc(t *testing.T, tokenBalanceIn, tokenWeightIn, tokenBalanceOut, tokenWeightOut, totalWeight, poolSupply, swapFee, exitFee string) (res testCase) { + var err error + res.tokenBalanceIn, err = sdk.NewDecFromStr(tokenBalanceIn) + require.NoError(t, err) + res.tokenWeightIn, err = sdk.NewDecFromStr(tokenWeightIn) + require.NoError(t, err) + res.tokenBalanceOut, err = sdk.NewDecFromStr(tokenBalanceOut) + require.NoError(t, err) + res.tokenWeightOut, err = sdk.NewDecFromStr(tokenWeightOut) + require.NoError(t, err) + if totalWeight == "" { + res.totalWeight = res.tokenWeightIn.Add(res.tokenWeightOut) + } else { + res.totalWeight, err = sdk.NewDecFromStr(totalWeight) + } + require.NoError(t, err) + res.poolSupply, err = sdk.NewDecFromStr(poolSupply) + require.NoError(t, err) + res.swapFee, err = sdk.NewDecFromStr(swapFee) + require.NoError(t, err) + res.exitFee, err = sdk.NewDecFromStr(exitFee) + require.NoError(t, err) + + return +} + +func randtc(t *testing.T, swapFee, exitFee sdk.Dec) (res testCase) { + res.tokenBalanceIn = sdk.NewInt(rand.Int63()).ToDec() + res.tokenWeightIn = sdk.NewInt(rand.Int63n(90) + 10).ToDec() + res.tokenBalanceOut = sdk.NewInt(rand.Int63()).ToDec() + res.tokenWeightOut = sdk.NewInt(rand.Int63n(90) + 10).ToDec() + res.totalWeight = res.tokenWeightIn.Add(res.tokenWeightOut) + res.poolSupply = sdk.NewInt(rand.Int63()).ToDec() + res.swapFee = swapFee + res.exitFee = exitFee + return +} + +func (tc testCase) calcInGivenOut(amount sdk.Dec) sdk.Dec { + return calcInGivenOut(tc.tokenBalanceIn, tc.tokenWeightIn, tc.tokenBalanceOut, tc.tokenWeightOut, amount, tc.swapFee) +} + +func (tc testCase) calcOutGivenIn(amount sdk.Dec) sdk.Dec { + return calcOutGivenIn(tc.tokenBalanceIn, tc.tokenWeightIn, tc.tokenBalanceOut, tc.tokenWeightOut, amount, tc.swapFee) +} + +func (tc testCase) calcPoolOutGivenSingleIn(amount sdk.Dec) sdk.Dec { + return calcPoolOutGivenSingleIn(tc.tokenBalanceIn, tc.tokenWeightIn.Quo(tc.totalWeight), tc.poolSupply, amount, tc.swapFee) +} + +func (tc testCase) calcPoolInGivenSingleOut(amount sdk.Dec) sdk.Dec { + return calcPoolInGivenSingleOut(tc.tokenBalanceOut, tc.tokenWeightOut.Quo(tc.totalWeight), tc.poolSupply, amount, tc.swapFee, tc.exitFee) +} + +func (tc testCase) calcSingleInGivenPoolOut(amount sdk.Dec) sdk.Dec { + return calcSingleInGivenPoolOut(tc.tokenBalanceIn, tc.tokenWeightIn.Quo(tc.totalWeight), tc.poolSupply, amount, tc.swapFee) +} + +func (tc testCase) calcSingleOutGivenPoolIn(amount sdk.Dec) sdk.Dec { + return calcSingleOutGivenPoolIn(tc.tokenBalanceOut, tc.tokenWeightOut.Quo(tc.totalWeight), tc.poolSupply, amount, tc.swapFee, tc.exitFee) +} + +func equalWithError(t *testing.T, x, y sdk.Dec, precision int64) { + require.True(t, x.Quo(y).Sub(sdk.OneDec()).Abs().LTE(sdk.OneDec().Quo(sdk.NewInt(precision).ToDec())), + "Not equal within error margin with difference %s: %s, %s", x.Quo(y).Sub(sdk.OneDec()), x, y) +} + +func TestCalcInverseInvariant(t *testing.T) { + tcs := make([]testCase, 10000) + for i := range tcs { + tcs[i] = randtc(t, sdk.NewInt(rand.Int63n(100)).ToDec().Quo(sdk.NewInt(1000).ToDec()), sdk.NewInt(rand.Int63n(100)).ToDec().Quo(sdk.NewInt(500).ToDec())) + } + + for _, tc := range tcs { + for i := 0; i < 10; i++ { + amount := sdk.NewInt(rand.Int63n(tc.tokenBalanceIn.TruncateInt().Int64() / 20)).ToDec() + + { + amountOut := tc.calcOutGivenIn(amount) + amount2 := tc.calcInGivenOut(amountOut) + equalWithError(t, amount, amount2, 100000) + } + + { + shareOut := tc.calcPoolOutGivenSingleIn(amount) + amount2 := tc.calcSingleInGivenPoolOut(shareOut) + equalWithError(t, amount, amount2, 100000) + } + + { + amountOut := sdk.NewInt(rand.Int63n(tc.tokenBalanceOut.TruncateInt().Int64() / 20)).ToDec() + shareIn := tc.calcPoolInGivenSingleOut(amountOut) + amount2 := tc.calcSingleOutGivenPoolIn(shareIn) + equalWithError(t, amountOut, amount2, 100000) + } + } + } +} diff --git a/x/gamm/keeper/pool_service.go b/x/gamm/keeper/pool_service.go index 939c520809d..ea7f787bdf5 100644 --- a/x/gamm/keeper/pool_service.go +++ b/x/gamm/keeper/pool_service.go @@ -191,11 +191,11 @@ func (k Keeper) JoinSwapExternAmountIn( return sdk.Int{}, err } + normalizedWeight := PoolAsset.Weight.ToDec().Quo(pool.GetTotalWeight().ToDec()) shareOutAmount = calcPoolOutGivenSingleIn( PoolAsset.Token.Amount.ToDec(), - PoolAsset.Weight.ToDec(), + normalizedWeight, pool.GetTotalShares().Amount.ToDec(), - pool.GetTotalWeight().ToDec(), tokenIn.Amount.ToDec(), pool.GetPoolSwapFee(), ).TruncateInt() @@ -259,11 +259,11 @@ func (k Keeper) JoinSwapShareAmountOut( return sdk.Int{}, err } + normalizedWeight := PoolAsset.Weight.ToDec().Quo(pool.GetTotalWeight().ToDec()) tokenInAmount = calcSingleInGivenPoolOut( PoolAsset.Token.Amount.ToDec(), - PoolAsset.Weight.ToDec(), + normalizedWeight, pool.GetTotalShares().Amount.ToDec(), - pool.GetTotalWeight().ToDec(), shareOutAmount.ToDec(), pool.GetPoolSwapFee(), ).TruncateInt() @@ -411,11 +411,11 @@ func (k Keeper) ExitSwapShareAmountIn( return sdk.Int{}, err } + normalizedWeight := PoolAsset.Weight.ToDec().Quo(pool.GetTotalWeight().ToDec()) tokenOutAmount = calcSingleOutGivenPoolIn( PoolAsset.Token.Amount.ToDec(), - PoolAsset.Weight.ToDec(), + normalizedWeight, pool.GetTotalShares().Amount.ToDec(), - pool.GetTotalWeight().ToDec(), shareInAmount.ToDec(), pool.GetPoolSwapFee(), pool.GetPoolExitFee(), @@ -496,11 +496,11 @@ func (k Keeper) ExitSwapExternAmountOut( return sdk.Int{}, err } + normalizedWeight := PoolAsset.Weight.ToDec().Quo(pool.GetTotalWeight().ToDec()) shareInAmount = calcPoolInGivenSingleOut( PoolAsset.Token.Amount.ToDec(), - PoolAsset.Weight.ToDec(), + normalizedWeight, pool.GetTotalShares().Amount.ToDec(), - pool.GetTotalWeight().ToDec(), tokenOut.Amount.ToDec(), pool.GetPoolSwapFee(), pool.GetPoolExitFee(), diff --git a/x/gamm/keeper/pool_test.go b/x/gamm/keeper/pool_test.go index f6476c7452c..28aa04e4737 100644 --- a/x/gamm/keeper/pool_test.go +++ b/x/gamm/keeper/pool_test.go @@ -124,7 +124,7 @@ func (suite *KeeperTestSuite) TestCleanupPoolRandomized() { for _, coin := range coinOf[acc.String()] { amt := suite.app.BankKeeper.GetBalance(suite.ctx, acc, coin.Denom) // the refund could have rounding error - suite.True(amt.Amount.Equal(coin.Amount) || amt.Amount.Equal(coin.Amount.SubRaw(1)), + suite.True(amt.Amount.Sub(coin.Amount).Abs().LTE(sdk.NewInt(2)), "Expected equal %s: %d, %d", amt.Denom, amt.Amount.Int64(), coin.Amount.Int64()) } }