forked from dydxprotocol/v4-chain
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CT-950] safety heap methods (dydxprotocol#1821)
- Loading branch information
Showing
2 changed files
with
375 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
package keeper | ||
|
||
import ( | ||
"cosmossdk.io/store/prefix" | ||
sdk "github.com/cosmos/cosmos-sdk/types" | ||
"github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types" | ||
) | ||
|
||
// RemoveSubaccountFromSafetyHeap removes a subaccount from the safety heap | ||
// given a peretual and side. | ||
func (k Keeper) RemoveSubaccountFromSafetyHeap( | ||
ctx sdk.Context, | ||
subaccountId types.SubaccountId, | ||
perpetualId uint32, | ||
side types.SafetyHeapPositionSide, | ||
) { | ||
store := k.GetSafetyHeapStore(ctx, perpetualId, side) | ||
index := k.MustGetSubaccountHeapIndex(store, subaccountId) | ||
k.MustRemoveElementAtIndex(ctx, store, index) | ||
} | ||
|
||
// AddSubaccountToSafetyHeap adds a subaccount to the safety heap | ||
// given a perpetual and side. | ||
func (k Keeper) AddSubaccountToSafetyHeap( | ||
ctx sdk.Context, | ||
subaccountId types.SubaccountId, | ||
perpetualId uint32, | ||
side types.SafetyHeapPositionSide, | ||
) { | ||
store := k.GetSafetyHeapStore(ctx, perpetualId, side) | ||
k.Insert(ctx, store, subaccountId) | ||
} | ||
|
||
// Heap methods | ||
|
||
// Insert inserts a subaccount into the safety heap. | ||
func (k Keeper) Insert( | ||
ctx sdk.Context, | ||
store prefix.Store, | ||
subaccountId types.SubaccountId, | ||
) { | ||
// Add the subaccount to the end of the heap. | ||
length := k.GetSafetyHeapLength(store) | ||
k.SetSubaccountAtIndex(store, subaccountId, length) | ||
|
||
// Increment the size of the heap. | ||
k.SetSafetyHeapLength(store, length+1) | ||
|
||
// Heapify up the element at the end of the heap | ||
// to restore the heap property. | ||
k.HeapifyUp(ctx, store, length) | ||
} | ||
|
||
// MustRemoveElementAtIndex removes the element at the given index | ||
// from the safety heap. | ||
func (k Keeper) MustRemoveElementAtIndex( | ||
ctx sdk.Context, | ||
store prefix.Store, | ||
index uint32, | ||
) { | ||
length := k.GetSafetyHeapLength(store) | ||
if index >= length { | ||
panic(types.ErrSafetyHeapSubaccountNotFoundAtIndex) | ||
} | ||
|
||
// Swap the element with the last element. | ||
k.Swap(store, index, length-1) | ||
|
||
// Remove the last element. | ||
k.DeleteSubaccountAtIndex(store, length-1) | ||
k.SetSafetyHeapLength(store, length-1) | ||
|
||
// Heapify down and up the element at the given index | ||
// to restore the heap property. | ||
if index < length-1 { | ||
k.HeapifyDown(ctx, store, index) | ||
k.HeapifyUp(ctx, store, index) | ||
} | ||
} | ||
|
||
// HeapifyUp moves the element at the given index up the heap | ||
// until the heap property is restored. | ||
func (k Keeper) HeapifyUp( | ||
ctx sdk.Context, | ||
store prefix.Store, | ||
index uint32, | ||
) { | ||
if index == 0 { | ||
return | ||
} | ||
|
||
parentIndex := (index - 1) / 2 | ||
if k.Less(ctx, store, index, parentIndex) { | ||
k.Swap(store, index, parentIndex) | ||
k.HeapifyUp(ctx, store, parentIndex) | ||
} | ||
} | ||
|
||
// HeapifyDown moves the element at the given index down the heap | ||
// until the heap property is restored. | ||
func (k Keeper) HeapifyDown( | ||
ctx sdk.Context, | ||
store prefix.Store, | ||
index uint32, | ||
) { | ||
leftIndex, rightIndex := 2*index+1, 2*index+2 | ||
|
||
length := k.GetSafetyHeapLength(store) | ||
if rightIndex < length && k.Less(ctx, store, rightIndex, leftIndex) { | ||
// Compare the current node with the right child | ||
// if right child exists and is less than the left child. | ||
if k.Less(ctx, store, rightIndex, index) { | ||
k.Swap(store, index, rightIndex) | ||
k.HeapifyDown(ctx, store, rightIndex) | ||
} | ||
} else if leftIndex < length { | ||
// Compare the current node with the left child | ||
// if left child exists. | ||
if k.Less(ctx, store, leftIndex, index) { | ||
k.Swap(store, index, leftIndex) | ||
k.HeapifyDown(ctx, store, leftIndex) | ||
} | ||
} | ||
} | ||
|
||
// Swap swaps the elements at the given indices. | ||
func (k Keeper) Swap( | ||
store prefix.Store, | ||
index1 uint32, | ||
index2 uint32, | ||
) { | ||
// No-op case | ||
if index1 == index2 { | ||
return | ||
} | ||
|
||
first := k.MustGetSubaccountAtIndex(store, index1) | ||
second := k.MustGetSubaccountAtIndex(store, index2) | ||
k.SetSubaccountAtIndex(store, first, index2) | ||
k.SetSubaccountAtIndex(store, second, index1) | ||
} | ||
|
||
// Less returns true if the element at the first index is less than | ||
// the element at the second index. | ||
func (k Keeper) Less( | ||
ctx sdk.Context, | ||
store prefix.Store, | ||
first uint32, | ||
second uint32, | ||
) bool { | ||
firstSubaccountId := k.MustGetSubaccountAtIndex(store, first) | ||
secondSubaccountId := k.MustGetSubaccountAtIndex(store, second) | ||
|
||
firstRisk, err := k.GetNetCollateralAndMarginRequirements( | ||
ctx, | ||
types.Update{ | ||
SubaccountId: firstSubaccountId, | ||
}, | ||
) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
secondRisk, err := k.GetNetCollateralAndMarginRequirements( | ||
ctx, | ||
types.Update{ | ||
SubaccountId: secondSubaccountId, | ||
}, | ||
) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
// Compare the risks of the two subaccounts and sort | ||
// them in descending order. | ||
return firstRisk.Cmp(secondRisk) > 0 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
package keeper_test | ||
|
||
import ( | ||
"math/big" | ||
"math/rand" | ||
"testing" | ||
|
||
"cosmossdk.io/store/prefix" | ||
sdk "github.com/cosmos/cosmos-sdk/types" | ||
"github.com/dydxprotocol/v4-chain/protocol/app/config" | ||
"github.com/dydxprotocol/v4-chain/protocol/testutil/constants" | ||
keepertest "github.com/dydxprotocol/v4-chain/protocol/testutil/keeper" | ||
testutil "github.com/dydxprotocol/v4-chain/protocol/testutil/util" | ||
"github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/keeper" | ||
satypes "github.com/dydxprotocol/v4-chain/protocol/x/subaccounts/types" | ||
"github.com/stretchr/testify/require" | ||
"gopkg.in/typ.v4/slices" | ||
) | ||
|
||
func TestSafetyHeapInsertRemoveMin(t *testing.T) { | ||
perpetualId := uint32(0) | ||
side := satypes.Long | ||
totalSubaccounts := 1000 | ||
|
||
// Create 1000 subaccounts with balances ranging from -500 to 500. | ||
// The subaccounts should be sorted by balance. | ||
allSubaccounts := make([]satypes.Subaccount, 0) | ||
for i := 0; i < totalSubaccounts; i++ { | ||
subaccount := satypes.Subaccount{ | ||
Id: &satypes.SubaccountId{ | ||
Owner: sdk.MustBech32ifyAddressBytes( | ||
config.Bech32PrefixAccAddr, | ||
constants.AliceAccAddress, | ||
), | ||
Number: uint32(i), | ||
}, | ||
AssetPositions: testutil.CreateUsdcAssetPositions( | ||
// Create asset positions with balances ranging from -500 to 500. | ||
big.NewInt(int64(i - totalSubaccounts/2)), | ||
), | ||
} | ||
|
||
// Handle special case. | ||
if i-totalSubaccounts/2 == 0 { | ||
subaccount.AssetPositions = nil | ||
} | ||
|
||
allSubaccounts = append(allSubaccounts, subaccount) | ||
} | ||
|
||
for iter := 0; iter < 100; iter++ { | ||
// Setup keeper state and test parameters. | ||
ctx, subaccountsKeeper, _, _, _, _, _, _, _, _ := keepertest.SubaccountsKeepers(t, false) | ||
|
||
// Shuffle the subaccounts so that insertion order is random. | ||
slices.Shuffle(allSubaccounts) | ||
|
||
store := subaccountsKeeper.GetSafetyHeapStore(ctx, perpetualId, side) | ||
for i, subaccount := range allSubaccounts { | ||
subaccountsKeeper.SetSubaccount(ctx, subaccount) | ||
subaccountsKeeper.AddSubaccountToSafetyHeap( | ||
ctx, | ||
*subaccount.Id, | ||
perpetualId, | ||
side, | ||
) | ||
|
||
require.Equal( | ||
t, | ||
uint32(i+1), | ||
subaccountsKeeper.GetSafetyHeapLength(store), | ||
) | ||
} | ||
|
||
// Make sure subaccounts are sorted correctly. | ||
for i := 0; i < totalSubaccounts; i++ { | ||
// Get the subaccount with the lowest safety score. | ||
// In this case, the subaccount with the lowest USDC balance. | ||
subaccountId := subaccountsKeeper.MustGetSubaccountAtIndex(store, uint32(0)) | ||
subaccount := subaccountsKeeper.GetSubaccount(ctx, subaccountId) | ||
|
||
// Subaccounts should be sorted by asset position balance. | ||
require.Equal(t, uint32(i), subaccountId.Number) | ||
require.Equal( | ||
t, | ||
big.NewInt(int64(i-totalSubaccounts/2)), | ||
subaccount.GetUsdcPosition(), | ||
) | ||
|
||
// Remove the subaccount from the heap. | ||
subaccountsKeeper.RemoveSubaccountFromSafetyHeap( | ||
ctx, | ||
subaccountId, | ||
perpetualId, | ||
side, | ||
) | ||
require.Equal( | ||
t, | ||
uint32(totalSubaccounts-i-1), | ||
subaccountsKeeper.GetSafetyHeapLength(store), | ||
) | ||
} | ||
} | ||
} | ||
|
||
func TestSafetyHeapInsertRemoveIndex(t *testing.T) { | ||
perpetualId := uint32(0) | ||
side := satypes.Long | ||
totalSubaccounts := 100 | ||
|
||
// Create 1000 subaccounts with balances ranging from -500 to 500. | ||
// The subaccounts should be sorted by balance. | ||
allSubaccounts := make([]satypes.Subaccount, 0) | ||
for i := 0; i < totalSubaccounts; i++ { | ||
subaccount := satypes.Subaccount{ | ||
Id: &satypes.SubaccountId{ | ||
Owner: sdk.MustBech32ifyAddressBytes( | ||
config.Bech32PrefixAccAddr, | ||
constants.AliceAccAddress, | ||
), | ||
Number: uint32(i), | ||
}, | ||
AssetPositions: testutil.CreateUsdcAssetPositions( | ||
// Create asset positions with balances ranging from -500 to 500. | ||
big.NewInt(int64(i - totalSubaccounts/2)), | ||
), | ||
} | ||
|
||
// Handle special case. | ||
if i-totalSubaccounts/2 == 0 { | ||
subaccount.AssetPositions = nil | ||
} | ||
|
||
allSubaccounts = append(allSubaccounts, subaccount) | ||
} | ||
|
||
for iter := 0; iter < 100; iter++ { | ||
// Setup keeper state and test parameters. | ||
ctx, subaccountsKeeper, _, _, _, _, _, _, _, _ := keepertest.SubaccountsKeepers(t, false) | ||
|
||
// Shuffle the subaccounts so that insertion order is random. | ||
slices.Shuffle(allSubaccounts) | ||
|
||
store := subaccountsKeeper.GetSafetyHeapStore(ctx, perpetualId, side) | ||
for i, subaccount := range allSubaccounts { | ||
subaccountsKeeper.SetSubaccount(ctx, subaccount) | ||
subaccountsKeeper.AddSubaccountToSafetyHeap( | ||
ctx, | ||
*subaccount.Id, | ||
perpetualId, | ||
side, | ||
) | ||
|
||
require.Equal( | ||
t, | ||
uint32(i+1), | ||
subaccountsKeeper.GetSafetyHeapLength(store), | ||
) | ||
} | ||
|
||
for i := totalSubaccounts; i > 0; i-- { | ||
// Remove a random subaccount from the heap. | ||
index := rand.Intn(i) | ||
|
||
subaccountId := subaccountsKeeper.MustGetSubaccountAtIndex(store, uint32(index)) | ||
subaccountsKeeper.RemoveSubaccountFromSafetyHeap( | ||
ctx, | ||
subaccountId, | ||
perpetualId, | ||
side, | ||
) | ||
|
||
require.Equal( | ||
t, | ||
uint32(i-1), | ||
subaccountsKeeper.GetSafetyHeapLength(store), | ||
) | ||
|
||
// Verify that the heap property is restored. | ||
verifyHeapProperties(t, subaccountsKeeper, ctx, store, 0) | ||
} | ||
} | ||
} | ||
|
||
func verifyHeapProperties(t *testing.T, k *keeper.Keeper, ctx sdk.Context, store prefix.Store, index uint32) { | ||
length := k.GetSafetyHeapLength(store) | ||
leftChild, rightChild := 2*index+1, 2*index+2 | ||
|
||
if leftChild < length { | ||
require.True(t, k.Less(ctx, store, index, leftChild)) | ||
verifyHeapProperties(t, k, ctx, store, leftChild) | ||
} | ||
|
||
if rightChild < length { | ||
require.True(t, k.Less(ctx, store, index, rightChild)) | ||
verifyHeapProperties(t, k, ctx, store, rightChild) | ||
} | ||
} |