diff --git a/.changeset/entries/861a96d7026170f055a293299b782c79785a8c6c3f7034ebe62f3ea878266583.yaml b/.changeset/entries/861a96d7026170f055a293299b782c79785a8c6c3f7034ebe62f3ea878266583.yaml new file mode 100644 index 00000000..ab62de2a --- /dev/null +++ b/.changeset/entries/861a96d7026170f055a293299b782c79785a8c6c3f7034ebe62f3ea878266583.yaml @@ -0,0 +1,6 @@ +type: refactor +module: none +pull_request: 208 +description: Apply Cosmos SDK v0.50 conventions +backward_compatible: true +date: 2024-12-10T08:38:13.018999Z diff --git a/x/assets/keeper/keeper.go b/x/assets/keeper/keeper.go index e8f4e417..fcaa85f3 100644 --- a/x/assets/keeper/keeper.go +++ b/x/assets/keeper/keeper.go @@ -13,8 +13,7 @@ import ( ) type Keeper struct { - cdc codec.Codec - storeService corestoretypes.KVStoreService + cdc codec.Codec Schema collections.Schema Assets collections.Map[string, types.Asset] // denom => types.Asset @@ -30,8 +29,7 @@ func NewKeeper( ) *Keeper { sb := collections.NewSchemaBuilder(storeService) k := &Keeper{ - cdc: cdc, - storeService: storeService, + cdc: cdc, Assets: collections.NewMap( sb, diff --git a/x/liquidvesting/keeper/send_restriction_test.go b/x/liquidvesting/keeper/send_restriction_test.go index 99183a4e..f1eb9f4c 100644 --- a/x/liquidvesting/keeper/send_restriction_test.go +++ b/x/liquidvesting/keeper/send_restriction_test.go @@ -81,9 +81,8 @@ func (suite *KeeperTestSuite) TestKeeper_BankSend() { lockedDenom, err := types.GetLockedRepresentationDenom("stake") suite.Require().NoError(err) - pool, found, err := suite.pk.GetPool(ctx, 1) + pool, err := suite.pk.GetPool(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) poolCoins := suite.bk.GetAllBalances(ctx, sdk.MustAccAddressFromBech32(pool.GetAddress())) suite.Require().Equal(sdk.NewCoins(sdk.NewInt64Coin(lockedDenom, 1000)), poolCoins) @@ -126,9 +125,8 @@ func (suite *KeeperTestSuite) TestKeeper_BankSend() { lockedDenom, err := types.GetLockedRepresentationDenom("stake") suite.Require().NoError(err) - operator, found, err := suite.ok.GetOperator(ctx, 1) + operator, err := suite.ok.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) poolCoins := suite.bk.GetAllBalances(ctx, sdk.MustAccAddressFromBech32(operator.GetAddress())) suite.Require().Equal(sdk.NewCoins(sdk.NewInt64Coin(lockedDenom, 1000)), poolCoins) @@ -173,9 +171,8 @@ func (suite *KeeperTestSuite) TestKeeper_BankSend() { lockedDenom, err := types.GetLockedRepresentationDenom("stake") suite.Require().NoError(err) - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) poolCoins := suite.bk.GetAllBalances(ctx, sdk.MustAccAddressFromBech32(service.GetAddress())) suite.Require().Equal(sdk.NewCoins(sdk.NewInt64Coin(lockedDenom, 1000)), poolCoins) diff --git a/x/operators/abci_test.go b/x/operators/abci_test.go index 4a0a8de2..4b9eb0f5 100644 --- a/x/operators/abci_test.go +++ b/x/operators/abci_test.go @@ -78,9 +78,8 @@ func TestBeginBlocker(t *testing.T) { }, check: func(ctx sdk.Context) { // Make sure the operator is still inactivating - operator, found, err := operatorsKeeper.GetOperator(ctx, 1) + operator, err := operatorsKeeper.GetOperator(ctx, 1) require.NoError(t, err) - require.True(t, found) require.Equal(t, types.OPERATOR_STATUS_INACTIVATING, operator.Status) // Make sure the operator is still in the inactivating queue @@ -116,9 +115,8 @@ func TestBeginBlocker(t *testing.T) { }, check: func(ctx sdk.Context) { // Make sure the operator is inactive - operator, found, err := operatorsKeeper.GetOperator(ctx, 1) + operator, err := operatorsKeeper.GetOperator(ctx, 1) require.NoError(t, err) - require.True(t, found) require.Equal(t, types.OPERATOR_STATUS_INACTIVE, operator.Status) // Make sure the operator is not in the inactivating queue diff --git a/x/operators/keeper/alias_functions.go b/x/operators/keeper/alias_functions.go index 6426e303..cc001e0a 100644 --- a/x/operators/keeper/alias_functions.go +++ b/x/operators/keeper/alias_functions.go @@ -2,9 +2,10 @@ package keeper import ( "context" - "fmt" + "errors" "time" + "cosmossdk.io/collections" storetypes "cosmossdk.io/store/types" "github.com/cosmos/cosmos-sdk/runtime" "github.com/cosmos/cosmos-sdk/telemetry" @@ -23,29 +24,10 @@ func (k *Keeper) createAccountIfNotExists(ctx context.Context, address sdk.AccAd // IterateOperators iterates over the operators in the store and performs a callback function func (k *Keeper) IterateOperators(ctx context.Context, cb func(operator types.Operator) (stop bool, err error)) error { - iterator, err := k.operators.Iterate(ctx, nil) - if err != nil { - return err - } - defer iterator.Close() - - for ; iterator.Valid(); iterator.Next() { - operator, err := iterator.Value() - if err != nil { - return err - } - - stop, err := cb(operator) - if err != nil { - return err - } - - if stop { - break - } - } - - return nil + err := k.operators.Walk(ctx, nil, func(_ uint32, operator types.Operator) (stop bool, err error) { + return cb(operator) + }) + return err } // GetOperators returns the operators stored in the KVStore @@ -63,15 +45,13 @@ func (k *Keeper) GetOperators(ctx context.Context) ([]types.Operator, error) { func (k *Keeper) IterateInactivatingOperatorQueue(ctx context.Context, endTime time.Time, fn func(operator types.Operator) (stop bool, err error)) error { return k.iterateInactivatingOperatorsKeys(ctx, endTime, func(key, value []byte) (stop bool, err error) { operatorID, _ := types.SplitInactivatingOperatorQueueKey(key) - operator, found, err := k.GetOperator(ctx, operatorID) + operator, err := k.GetOperator(ctx, operatorID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return true, types.ErrOperatorNotFound + } return true, err } - - if !found { - return true, fmt.Errorf("operator %d does not exist", operatorID) - } - return fn(operator) }) } @@ -126,26 +106,13 @@ func (k *Keeper) IsOperatorAddress(ctx context.Context, address string) (bool, e // GetAllOperatorParamsRecords returns all the operator params records func (k *Keeper) GetAllOperatorParamsRecords(ctx context.Context) ([]types.OperatorParamsRecord, error) { - iterator, err := k.operatorParams.Iterate(ctx, nil) - if err != nil { - return nil, err - } - defer iterator.Close() - var records []types.OperatorParamsRecord - for ; iterator.Valid(); iterator.Next() { - // Get the operator params - params, err := iterator.Value() - if err != nil { - return nil, err - } - // Get the operator id from the map key - operatorID, err := iterator.Key() - if err != nil { - return nil, err - } + err := k.operatorParams.Walk(ctx, nil, func(operatorID uint32, params types.OperatorParams) (stop bool, err error) { records = append(records, types.NewOperatorParamsRecord(operatorID, params)) + return false, nil + }) + if err != nil { + return nil, err } - return records, nil } diff --git a/x/operators/keeper/genesis.go b/x/operators/keeper/genesis.go index fd7a39aa..4b7e0ffe 100644 --- a/x/operators/keeper/genesis.go +++ b/x/operators/keeper/genesis.go @@ -1,8 +1,10 @@ package keeper import ( - "fmt" + "errors" + "cosmossdk.io/collections" + errorsmod "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/milkyway-labs/milkyway/v3/x/operators/types" @@ -65,15 +67,14 @@ func (k *Keeper) InitGenesis(ctx sdk.Context, state *types.GenesisState) error { // Store the operator params for _, operatorParams := range state.OperatorsParams { // Ensure that the operator is present - _, found, err := k.GetOperator(ctx, operatorParams.OperatorID) + _, err := k.GetOperator(ctx, operatorParams.OperatorID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return errorsmod.Wrapf(types.ErrOperatorNotFound, "operator %d not found", operatorParams.OperatorID) + } return err } - if !found { - return fmt.Errorf("can't set operator params for %d, operator not found", operatorParams.OperatorID) - } - err = k.SaveOperatorParams(ctx, operatorParams.OperatorID, operatorParams.Params) if err != nil { return err diff --git a/x/operators/keeper/genesis_test.go b/x/operators/keeper/genesis_test.go index 556b816f..e4378359 100644 --- a/x/operators/keeper/genesis_test.go +++ b/x/operators/keeper/genesis_test.go @@ -289,9 +289,8 @@ func (suite *KeeperTestSuite) TestKeeper_InitGenesis() { Params: types.DefaultParams(), }, check: func(ctx sdk.Context) { - operator, found, err := suite.k.GetOperator(ctx, 1) + operator, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -305,9 +304,8 @@ func (suite *KeeperTestSuite) TestKeeper_InitGenesis() { params, err := suite.k.GetOperatorParams(ctx, 1) suite.Require().Equal(types.DefaultOperatorParams(), params) - operator, found, err = suite.k.GetOperator(ctx, 2) + operator, err = suite.k.GetOperator(ctx, 2) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 2, types.OPERATOR_STATUS_INACTIVATING, diff --git a/x/operators/keeper/grpc_query.go b/x/operators/keeper/grpc_query.go index eb4e41ba..954b9f6c 100644 --- a/x/operators/keeper/grpc_query.go +++ b/x/operators/keeper/grpc_query.go @@ -2,9 +2,9 @@ package keeper import ( "context" + "errors" - "cosmossdk.io/store/prefix" - "github.com/cosmos/cosmos-sdk/runtime" + "cosmossdk.io/collections" "github.com/cosmos/cosmos-sdk/types/query" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -16,37 +16,25 @@ var _ types.QueryServer = &Keeper{} // Operator implements the Query/Operator gRPC method func (k *Keeper) Operator(ctx context.Context, request *types.QueryOperatorRequest) (*types.QueryOperatorResponse, error) { - operator, found, err := k.GetOperator(ctx, request.OperatorId) + operator, err := k.GetOperator(ctx, request.OperatorId) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "operator not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "operator not found") - } - return &types.QueryOperatorResponse{Operator: operator}, nil } // Operators implements the Query/Operators gRPC method func (k *Keeper) Operators(ctx context.Context, request *types.QueryOperatorsRequest) (*types.QueryOperatorsResponse, error) { - store := k.storeService.OpenKVStore(ctx) - operatorsStore := prefix.NewStore(runtime.KVStoreAdapter(store), types.OperatorPrefix) - - var operators []types.Operator - pageRes, err := query.Paginate(operatorsStore, request.Pagination, func(key []byte, value []byte) error { - var operator types.Operator - if err := k.cdc.Unmarshal(value, &operator); err != nil { - return status.Error(codes.Internal, err.Error()) - } - - operators = append(operators, operator) - return nil + operators, pageRes, err := query.CollectionPaginate(ctx, k.operators, request.Pagination, func(_ uint32, operator types.Operator) (types.Operator, error) { + return operator, nil }) if err != nil { return nil, status.Error(codes.Internal, err.Error()) } - return &types.QueryOperatorsResponse{ Operators: operators, Pagination: pageRes, @@ -54,15 +42,14 @@ func (k *Keeper) Operators(ctx context.Context, request *types.QueryOperatorsReq } func (k *Keeper) OperatorParams(ctx context.Context, request *types.QueryOperatorParamsRequest) (*types.QueryOperatorParamsResponse, error) { - _, found, err := k.GetOperator(ctx, request.OperatorId) + _, err := k.GetOperator(ctx, request.OperatorId) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "operator not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, types.ErrOperatorNotFound - } - params, err := k.GetOperatorParams(ctx, request.OperatorId) if err != nil { return nil, err diff --git a/x/operators/keeper/msg_server.go b/x/operators/keeper/msg_server.go index d5c13e09..19dfc250 100644 --- a/x/operators/keeper/msg_server.go +++ b/x/operators/keeper/msg_server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "cosmossdk.io/collections" "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -98,15 +99,14 @@ func (k msgServer) RegisterOperator(ctx context.Context, msg *types.MsgRegisterO // UpdateOperator defines the rpc method for Msg/UpdateOperator func (k msgServer) UpdateOperator(ctx context.Context, msg *types.MsgUpdateOperator) (*types.MsgUpdateOperatorResponse, error) { // Check if the operator exists - operator, found, err := k.GetOperator(ctx, msg.OperatorID) + operator, err := k.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, types.ErrOperatorNotFound - } - // Make sure only the admin can update the operator if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can update the operator") @@ -141,15 +141,14 @@ func (k msgServer) UpdateOperator(ctx context.Context, msg *types.MsgUpdateOpera // DeactivateOperator defines the rpc method for Msg/DeactivateOperator func (k msgServer) DeactivateOperator(ctx context.Context, msg *types.MsgDeactivateOperator) (*types.MsgDeactivateOperatorResponse, error) { // Check if the operator exists - operator, found, err := k.GetOperator(ctx, msg.OperatorID) + operator, err := k.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, types.ErrOperatorNotFound - } - // Make sure only the admin can deactivate the operator if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can deactivate the operator") @@ -175,15 +174,14 @@ func (k msgServer) DeactivateOperator(ctx context.Context, msg *types.MsgDeactiv // ReactivateOperator defines the rpc method for Msg/ReactivateOperator func (k msgServer) ReactivateOperator(ctx context.Context, msg *types.MsgReactivateOperator) (*types.MsgReactivateOperatorResponse, error) { // Check if the operator exists - operator, found, err := k.GetOperator(ctx, msg.OperatorID) + operator, err := k.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, types.ErrOperatorNotFound - } - // Make sure only the admin can reactivate the operator if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can deactivate the operator") @@ -209,15 +207,14 @@ func (k msgServer) ReactivateOperator(ctx context.Context, msg *types.MsgReactiv // DeleteOperator defines the rpc method for Msg/DeleteOperator func (k msgServer) DeleteOperator(ctx context.Context, msg *types.MsgDeleteOperator) (*types.MsgDeleteOperatorResponse, error) { // Check if the operator exists - operator, found, err := k.GetOperator(ctx, msg.OperatorID) + operator, err := k.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, types.ErrOperatorNotFound - } - // Make sure only the admin can delete the operator if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can delete the operator") @@ -243,15 +240,14 @@ func (k msgServer) DeleteOperator(ctx context.Context, msg *types.MsgDeleteOpera // TransferOperatorOwnership defines the rpc method for Msg/TransferOperatorOwnership func (k msgServer) TransferOperatorOwnership(ctx context.Context, msg *types.MsgTransferOperatorOwnership) (*types.MsgTransferOperatorOwnershipResponse, error) { // Check if the operator exists - operator, found, err := k.GetOperator(ctx, msg.OperatorID) + operator, err := k.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, types.ErrOperatorNotFound - } - // Make sure only the admin can transfer the operator ownership if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can transfer the operator ownership") @@ -278,15 +274,14 @@ func (k msgServer) TransferOperatorOwnership(ctx context.Context, msg *types.Msg // SetOperatorParams defines the rpc method for Msg/SetOperatorParams func (k msgServer) SetOperatorParams(ctx context.Context, msg *types.MsgSetOperatorParams) (*types.MsgSetOperatorParamsResponse, error) { - operator, found, err := k.GetOperator(ctx, msg.OperatorID) + operator, err := k.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, types.ErrOperatorNotFound - } - // Make sure only the admin can update the operator if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can update the operator params") diff --git a/x/operators/keeper/msg_server_test.go b/x/operators/keeper/msg_server_test.go index e7d46f47..bdda383d 100644 --- a/x/operators/keeper/msg_server_test.go +++ b/x/operators/keeper/msg_server_test.go @@ -3,6 +3,7 @@ package keeper_test import ( "time" + "cosmossdk.io/collections" sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" @@ -122,9 +123,8 @@ func (suite *KeeperTestSuite) TestMsgServer_RegisterOperator() { }, check: func(ctx sdk.Context) { // Make sure the operator was stored - stored, found, err := suite.k.GetOperator(ctx, 2) + stored, err := suite.k.GetOperator(ctx, 2) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 2, types.OPERATOR_STATUS_ACTIVE, @@ -210,9 +210,8 @@ func (suite *KeeperTestSuite) TestMsgServer_RegisterOperator() { }, check: func(ctx sdk.Context) { // Make sure the operator was stored - stored, found, err := suite.k.GetOperator(ctx, 2) + stored, err := suite.k.GetOperator(ctx, 2) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 2, types.OPERATOR_STATUS_ACTIVE, @@ -387,9 +386,8 @@ func (suite *KeeperTestSuite) TestMsgServer_UpdateOperator() { }, check: func(ctx sdk.Context) { // Make sure the operator was updated - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -548,9 +546,8 @@ func (suite *KeeperTestSuite) TestMsgServer_DeactivateOperator() { }, check: func(ctx sdk.Context) { // Make sure the operator was updated - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_INACTIVATING, @@ -703,9 +700,8 @@ func (suite *KeeperTestSuite) TestMsgServer_ReactivateOperator() { ), }, check: func(ctx sdk.Context) { - operator, found, err := suite.k.GetOperator(ctx, 1) + operator, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -823,9 +819,8 @@ func (suite *KeeperTestSuite) TestMsgServer_TransferOperatorOwnership() { }, check: func(ctx sdk.Context) { // Make sure the operator was updated - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -971,9 +966,8 @@ func (suite *KeeperTestSuite) TestMsgServer_DeleteOperator() { }, check: func(ctx sdk.Context) { // Make sure the operator was updated - _, found, err := suite.k.GetOperator(ctx, 1) - suite.Require().NoError(err) - suite.Require().False(found) + _, err := suite.k.GetOperator(ctx, 1) + suite.Require().ErrorIs(err, collections.ErrNotFound) // Ensure the hook has been called suite.Require().True(suite.hooks.CalledMap["BeforeOperatorDeleted"]) diff --git a/x/operators/keeper/operators.go b/x/operators/keeper/operators.go index 16adc5df..ab65ed2b 100644 --- a/x/operators/keeper/operators.go +++ b/x/operators/keeper/operators.go @@ -56,15 +56,8 @@ func (k *Keeper) CreateOperator(ctx context.Context, operator types.Operator) er // GetOperator returns the operator with the given ID. // If the operator does not exist, false is returned. -func (k *Keeper) GetOperator(ctx context.Context, operatorID uint32) (operator types.Operator, found bool, err error) { - operator, err = k.operators.Get(ctx, operatorID) - if err != nil { - if errors.IsOf(err, collections.ErrNotFound) { - return types.Operator{}, false, nil - } - return types.Operator{}, false, err - } - return operator, true, nil +func (k *Keeper) GetOperator(ctx context.Context, operatorID uint32) (operator types.Operator, err error) { + return k.operators.Get(ctx, operatorID) } // SaveOperator stores the given operator in the KVStore diff --git a/x/operators/keeper/operators_test.go b/x/operators/keeper/operators_test.go index e6218b39..8d5dc5d7 100644 --- a/x/operators/keeper/operators_test.go +++ b/x/operators/keeper/operators_test.go @@ -3,6 +3,7 @@ package keeper_test import ( "time" + "cosmossdk.io/collections" sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" @@ -147,9 +148,8 @@ func (suite *KeeperTestSuite) TestKeeper_CreateOperator() { shouldErr: false, check: func(ctx sdk.Context) { // Make sure the operator has been stored - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -198,14 +198,12 @@ func (suite *KeeperTestSuite) TestKeeper_GetOperator() { setup func() store func(ctx sdk.Context) operatorID uint32 - shouldErr bool expFound bool expOperator types.Operator }{ { name: "non existing operator returns false", operatorID: 1, - shouldErr: false, expFound: false, }, { @@ -223,7 +221,6 @@ func (suite *KeeperTestSuite) TestKeeper_GetOperator() { }, operatorID: 1, expFound: true, - shouldErr: false, expOperator: types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -248,15 +245,12 @@ func (suite *KeeperTestSuite) TestKeeper_GetOperator() { tc.store(ctx) } - operator, found, err := suite.k.GetOperator(ctx, tc.operatorID) - if tc.shouldErr { - suite.Require().Error(err) - } else { + operator, err := suite.k.GetOperator(ctx, tc.operatorID) + if tc.expFound { suite.Require().NoError(err) - suite.Require().Equal(tc.expFound, found) - if tc.expFound { - suite.Require().Equal(tc.expOperator, operator) - } + suite.Require().Equal(tc.expOperator, operator) + } else { + suite.Require().ErrorIs(err, collections.ErrNotFound) } }) } @@ -283,9 +277,8 @@ func (suite *KeeperTestSuite) TestKeeper_SaveOperator() { ), shouldErr: false, check: func(ctx sdk.Context) { - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -319,9 +312,8 @@ func (suite *KeeperTestSuite) TestKeeper_SaveOperator() { ), shouldErr: false, check: func(ctx sdk.Context) { - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_INACTIVATING, @@ -431,9 +423,8 @@ func (suite *KeeperTestSuite) TestKeeper_StartOperatorInactivation() { ), check: func(ctx sdk.Context) { // Make sure the operator status has been updated - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_INACTIVATING, @@ -524,9 +515,8 @@ func (suite *KeeperTestSuite) TestKeeper_CompleteOperatorInactivation() { ), check: func(ctx sdk.Context) { // Make sure the operator status has been updated - stored, found, err := suite.k.GetOperator(ctx, 1) + stored, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_INACTIVE, @@ -637,9 +627,8 @@ func (suite *KeeperTestSuite) TestKeeper_ReactivateInactiveOperator() { operatorID: 1, shouldErr: false, check: func(ctx sdk.Context) { - operator, found, err := suite.k.GetOperator(ctx, 1) + operator, err := suite.k.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewOperator( 1, types.OPERATOR_STATUS_ACTIVE, @@ -665,11 +654,8 @@ func (suite *KeeperTestSuite) TestKeeper_ReactivateInactiveOperator() { if tc.store != nil { tc.store(ctx) } - operator, found, err := suite.k.GetOperator(ctx, tc.operatorID) + operator, err := suite.k.GetOperator(ctx, tc.operatorID) suite.Require().NoError(err) - if !found { - suite.Fail("operator not found") - } err = suite.k.ReactivateInactiveOperator(ctx, operator) if tc.shouldErr { diff --git a/x/pools/keeper/alias_functions.go b/x/pools/keeper/alias_functions.go index a220e3f9..6085206d 100644 --- a/x/pools/keeper/alias_functions.go +++ b/x/pools/keeper/alias_functions.go @@ -19,29 +19,10 @@ func (k *Keeper) createAccountIfNotExists(ctx context.Context, address sdk.AccAd // IteratePools iterates over the pools in the store and performs a callback function func (k *Keeper) IteratePools(ctx context.Context, cb func(pool types.Pool) (stop bool, err error)) error { - iterator, err := k.pools.Iterate(ctx, nil) - if err != nil { - return err - } - defer iterator.Close() - - for ; iterator.Valid(); iterator.Next() { - pool, err := iterator.Value() - if err != nil { - return err - } - - stop, err := cb(pool) - if err != nil { - return err - } - - if stop { - break - } - } - - return nil + err := k.pools.Walk(ctx, nil, func(_ uint32, pool types.Pool) (stop bool, err error) { + return cb(pool) + }) + return err } // GetPools returns the list of stored pools @@ -57,23 +38,20 @@ func (k *Keeper) GetPools(ctx context.Context) ([]types.Pool, error) { // GetPoolByDenom returns the pool for the given denom if it exists. // If the pool does not exist, false is returned instead func (k *Keeper) GetPoolByDenom(ctx context.Context, denom string) (types.Pool, bool, error) { - iterator, err := k.pools.Iterate(ctx, nil) + var poolFound types.Pool + err := k.pools.Walk(ctx, nil, func(_ uint32, pool types.Pool) (stop bool, err error) { + if pool.Denom == denom { + poolFound = pool + return true, nil + } + return false, nil + }) if err != nil { return types.Pool{}, false, err } - defer iterator.Close() - - for ; iterator.Valid(); iterator.Next() { - pool, err := iterator.Value() - if err != nil { - return types.Pool{}, false, err - } - - if pool.Denom == denom { - return pool, true, nil - } + if poolFound != (types.Pool{}) { + return poolFound, true, nil } - return types.Pool{}, false, nil } diff --git a/x/pools/keeper/grpc_query.go b/x/pools/keeper/grpc_query.go index b6c8d5c5..d878bfc2 100644 --- a/x/pools/keeper/grpc_query.go +++ b/x/pools/keeper/grpc_query.go @@ -2,7 +2,9 @@ package keeper import ( "context" + "errors" + "cosmossdk.io/collections" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/query" "google.golang.org/grpc/codes" @@ -19,15 +21,14 @@ func (k *Keeper) PoolByID(ctx context.Context, request *types.QueryPoolByIdReque return nil, status.Error(codes.InvalidArgument, "invalid pool id") } - pool, found, err := k.GetPool(ctx, request.PoolId) + pool, err := k.GetPool(ctx, request.PoolId) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "pool not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "pool not found") - } - return &types.QueryPoolResponse{Pool: pool}, nil } @@ -51,8 +52,8 @@ func (k *Keeper) PoolByDenom(ctx context.Context, request *types.QueryPoolByDeno // Pools implements the Query/Pools gRPC method func (k *Keeper) Pools(ctx context.Context, request *types.QueryPoolsRequest) (*types.QueryPoolsResponse, error) { - pools, pageRes, err := query.CollectionPaginate(ctx, k.pools, request.Pagination, func(key uint32, value types.Pool) (types.Pool, error) { - return value, nil + pools, pageRes, err := query.CollectionPaginate(ctx, k.pools, request.Pagination, func(_ uint32, pool types.Pool) (types.Pool, error) { + return pool, nil }) if err != nil { return nil, status.Error(codes.Internal, err.Error()) diff --git a/x/pools/keeper/keeper.go b/x/pools/keeper/keeper.go index cb1ceeed..f5ff7be7 100644 --- a/x/pools/keeper/keeper.go +++ b/x/pools/keeper/keeper.go @@ -13,9 +13,8 @@ import ( ) type Keeper struct { - cdc codec.Codec - storeService corestoretypes.KVStoreService - hooks types.PoolsHooks + cdc codec.Codec + hooks types.PoolsHooks accountKeeper types.AccountKeeper @@ -33,7 +32,6 @@ func NewKeeper(cdc codec.Codec, k := &Keeper{ cdc: cdc, - storeService: storeService, accountKeeper: accountKeeper, nextPoolID: collections.NewSequence( diff --git a/x/pools/keeper/pools.go b/x/pools/keeper/pools.go index 18c21ff4..f2be1ea4 100644 --- a/x/pools/keeper/pools.go +++ b/x/pools/keeper/pools.go @@ -3,7 +3,6 @@ package keeper import ( "context" - "cosmossdk.io/collections" "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -52,13 +51,6 @@ func (k *Keeper) SavePool(ctx context.Context, pool types.Pool) error { // GetPool retrieves the pool with the given ID from the store. // If the pool does not exist, false is returned instead -func (k *Keeper) GetPool(ctx context.Context, id uint32) (types.Pool, bool, error) { - pool, err := k.pools.Get(ctx, id) - if err != nil { - if errors.IsOf(err, collections.ErrNotFound) { - return types.Pool{}, false, nil - } - return types.Pool{}, false, err - } - return pool, true, nil +func (k *Keeper) GetPool(ctx context.Context, id uint32) (types.Pool, error) { + return k.pools.Get(ctx, id) } diff --git a/x/pools/keeper/pools_test.go b/x/pools/keeper/pools_test.go index c270724d..bb583ec6 100644 --- a/x/pools/keeper/pools_test.go +++ b/x/pools/keeper/pools_test.go @@ -1,6 +1,7 @@ package keeper_test import ( + "cosmossdk.io/collections" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/milkyway-labs/milkyway/v3/x/pools/types" @@ -124,9 +125,8 @@ func (suite *KeeperTestSuite) TestKeeper_SavePool() { pool: types.NewPool(1, "uatom"), check: func(ctx sdk.Context) { // Make sure the pool is saved properly - pool, found, err := suite.k.GetPool(ctx, 1) + pool, err := suite.k.GetPool(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewPool(1, "uatom"), pool) // Make sure the pool account is created @@ -147,9 +147,8 @@ func (suite *KeeperTestSuite) TestKeeper_SavePool() { shouldErr: false, check: func(ctx sdk.Context) { // Make sure the pool is saved properly - pool, found, err := suite.k.GetPool(ctx, 1) + pool, err := suite.k.GetPool(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewPool(1, "usdt"), pool) // Make sure the pool account is created @@ -186,14 +185,13 @@ func (suite *KeeperTestSuite) TestKeeper_SavePool() { func (suite *KeeperTestSuite) TestKeeper_GetPool() { testCases := []struct { - name string - setup func() - store func(ctx sdk.Context) - poolID uint32 - shouldErr bool - expFound bool - expPool types.Pool - check func(ctx sdk.Context) + name string + setup func() + store func(ctx sdk.Context) + poolID uint32 + expFound bool + expPool types.Pool + check func(ctx sdk.Context) }{ { name: "not found pool returns error", @@ -223,15 +221,12 @@ func (suite *KeeperTestSuite) TestKeeper_GetPool() { tc.store(ctx) } - pool, found, err := suite.k.GetPool(ctx, tc.poolID) - if tc.shouldErr { - suite.Require().Error(err) - } else { + pool, err := suite.k.GetPool(ctx, tc.poolID) + if tc.expFound { suite.Require().NoError(err) - suite.Require().Equal(tc.expFound, found) - if tc.expFound { - suite.Require().Equal(tc.expPool, pool) - } + suite.Require().Equal(tc.expPool, pool) + } else { + suite.Require().ErrorIs(err, collections.ErrNotFound) } if tc.check != nil { diff --git a/x/restaking/keeper/alias_functions.go b/x/restaking/keeper/alias_functions.go index ae6a80ef..685af19d 100644 --- a/x/restaking/keeper/alias_functions.go +++ b/x/restaking/keeper/alias_functions.go @@ -6,6 +6,7 @@ import ( "sort" "time" + "cosmossdk.io/collections" "cosmossdk.io/errors" "cosmossdk.io/math" storetypes "cosmossdk.io/store/types" @@ -26,41 +27,18 @@ import ( // IterateAllOperatorsJoinedServices iterates over all the operators and their joined services, // performing the given action. If the action returns true, the iteration will stop. func (k *Keeper) IterateAllOperatorsJoinedServices(ctx context.Context, action func(operatorID uint32, serviceID uint32) (stop bool, err error)) error { - iterator, err := k.operatorJoinedServices.Iterate(ctx, nil) - if err != nil { - return err - } - defer iterator.Close() - - for ; iterator.Valid(); iterator.Next() { - operatorServicePair, err := iterator.Key() - if err != nil { - return err - } - - stop, err := action(operatorServicePair.K1(), operatorServicePair.K2()) - if err != nil { - return err - } - - if stop { - break - } - } - - return nil + err := k.operatorJoinedServices.Walk(ctx, nil, func(key collections.Pair[uint32, uint32], _ collections.NoValue) (stop bool, err error) { + operatorID := key.K1() + serviceID := key.K2() + return action(operatorID, serviceID) + }) + return err } // GetAllOperatorsJoinedServices returns all services that each operator has joined func (k *Keeper) GetAllOperatorsJoinedServices(ctx context.Context) ([]types.OperatorJoinedServices, error) { - iterator, err := k.operatorJoinedServices.Iterate(ctx, nil) - if err != nil { - return nil, err - } - defer iterator.Close() - items := make(map[uint32]types.OperatorJoinedServices) - err = k.IterateAllOperatorsJoinedServices(ctx, func(operatorID uint32, serviceID uint32) (stop bool, err error) { + err := k.IterateAllOperatorsJoinedServices(ctx, func(operatorID uint32, serviceID uint32) (stop bool, err error) { joinedServicesRecord, ok := items[operatorID] if !ok { joinedServicesRecord = types.NewOperatorJoinedServices(operatorID, nil) @@ -95,31 +73,12 @@ func (k *Keeper) GetAllOperatorsJoinedServices(ctx context.Context) ([]types.Ope // IterateAllServicesAllowedOperators iterates over all the services and their allowed operators, // performing the given action. If the action returns true, the iteration will stop. func (k *Keeper) IterateAllServicesAllowedOperators(ctx context.Context, action func(serviceID uint32, operatorID uint32) (stop bool, err error)) error { - iterator, err := k.serviceOperatorsAllowList.Iterate(ctx, nil) - if err != nil { - return err - } - defer iterator.Close() - - for ; iterator.Valid(); iterator.Next() { - serviceOperatorPair, err := iterator.Key() - if err != nil { - return err - } - serviceID := serviceOperatorPair.K1() - operatorID := serviceOperatorPair.K2() - - stop, err := action(serviceID, operatorID) - if err != nil { - return err - } - - if stop { - break - } - } - - return nil + err := k.serviceOperatorsAllowList.Walk(ctx, nil, func(key collections.Pair[uint32, uint32]) (stop bool, err error) { + serviceID := key.K1() + operatorID := key.K2() + return action(serviceID, operatorID) + }) + return err } // GetAllServicesAllowedOperators returns all the operators that are allowed to secure a service for all the services @@ -159,29 +118,21 @@ func (k *Keeper) GetAllServicesAllowedOperators(ctx context.Context) ([]types.Se // GetAllServicesSecuringPools returns all the pools from which the services are allowed to borrow security func (k *Keeper) GetAllServicesSecuringPools(ctx context.Context) ([]types.ServiceSecuringPools, error) { - iterator, err := k.serviceSecuringPools.Iterate(ctx, nil) - if err != nil { - return nil, err - } - defer iterator.Close() - items := make(map[uint32]types.ServiceSecuringPools) - for ; iterator.Valid(); iterator.Next() { - servicePoolPair, err := iterator.Key() - if err != nil { - return nil, err - } - serviceID := servicePoolPair.K1() - poolID := servicePoolPair.K2() - + err := k.serviceSecuringPools.Walk(ctx, nil, func(key collections.Pair[uint32, uint32]) (stop bool, err error) { + serviceID := key.K1() + poolID := key.K2() securingPools, ok := items[serviceID] if !ok { securingPools = types.NewServiceSecuringPools(serviceID, nil) } securingPools.PoolIDs = append(securingPools.PoolIDs, poolID) items[serviceID] = securingPools + return false, nil + }) + if err != nil { + return nil, err } - if len(items) == 0 { return nil, nil } @@ -266,7 +217,7 @@ func (k *Keeper) GetDelegationForTarget( // GetDelegationTargetFromDelegation returns the target of the given delegation. func (k *Keeper) GetDelegationTargetFromDelegation( ctx context.Context, delegation types.Delegation, -) (types.DelegationTarget, bool, error) { +) (types.DelegationTarget, error) { switch delegation.Type { case types.DELEGATION_TYPE_POOL: return k.poolsKeeper.GetPool(ctx, delegation.TargetID) @@ -275,7 +226,7 @@ func (k *Keeper) GetDelegationTargetFromDelegation( case types.DELEGATION_TYPE_OPERATOR: return k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) default: - return nil, false, nil + return nil, nil } } @@ -574,15 +525,14 @@ func (k *Keeper) GetAllDelegations(ctx context.Context) ([]types.Delegation, err func (k *Keeper) GetAllUserRestakedCoins(ctx context.Context, userAddress string) (sdk.DecCoins, error) { totalDelegatedCoins := sdk.NewDecCoins() err := k.IterateUserDelegations(ctx, userAddress, func(d types.Delegation) (bool, error) { - target, found, err := k.GetDelegationTargetFromDelegation(ctx, d) + target, err := k.GetDelegationTargetFromDelegation(ctx, d) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return true, fmt.Errorf("can't find target for delegation %d, target id: %d", d.Type, d.TargetID) + } return true, err } - if !found { - return true, fmt.Errorf("can't find target for delegation %d, target id: %d", d.Type, d.TargetID) - } - totalDelegatedCoins = totalDelegatedCoins.Add(target.TokensFromShares(d.Shares)...) return false, nil }) @@ -677,33 +627,33 @@ func (k *Keeper) PerformDelegation(ctx context.Context, data types.DelegationDat func (k *Keeper) getUnbondingDelegationTarget(ctx context.Context, ubd types.UnbondingDelegation) (types.DelegationTarget, error) { switch ubd.Type { case types.DELEGATION_TYPE_POOL: - pool, found, err := k.poolsKeeper.GetPool(ctx, ubd.TargetID) + pool, err := k.poolsKeeper.GetPool(ctx, ubd.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, poolstypes.ErrPoolNotFound + } return nil, err } - if !found { - return nil, poolstypes.ErrPoolNotFound - } return pool, nil case types.DELEGATION_TYPE_OPERATOR: - operator, found, err := k.operatorsKeeper.GetOperator(ctx, ubd.TargetID) + operator, err := k.operatorsKeeper.GetOperator(ctx, ubd.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, operatorstypes.ErrOperatorNotFound - } return operator, nil case types.DELEGATION_TYPE_SERVICE: - service, found, err := k.servicesKeeper.GetService(ctx, ubd.TargetID) + service, err := k.servicesKeeper.GetService(ctx, ubd.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } return service, nil default: @@ -841,15 +791,14 @@ func (k *Keeper) UnbondRestakedAssets(ctx context.Context, user sdk.AccAddress, toUndelegateTokens := sdk.NewDecCoinsFromCoins(amount...) err := k.IterateUserDelegations(ctx, user.String(), func(delegation types.Delegation) (bool, error) { - target, found, err := k.GetDelegationTargetFromDelegation(ctx, delegation) + target, err := k.GetDelegationTargetFromDelegation(ctx, delegation) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return false, nil + } return true, err } - if !found { - return false, nil - } - // Compute the shares that this delegation should have to undelegate // all the remaining tokens involvedShares, err := target.SharesFromDecCoins(toUndelegateTokens) diff --git a/x/restaking/keeper/alias_functions_test.go b/x/restaking/keeper/alias_functions_test.go index be002a8a..d7461711 100644 --- a/x/restaking/keeper/alias_functions_test.go +++ b/x/restaking/keeper/alias_functions_test.go @@ -812,7 +812,7 @@ func (suite *KeeperTestSuite) TestKeeper_UnbondRestakedAssets() { suite.Require().NoError(err) suite.Assert().True(found) suite.Assert().Equal(types.DELEGATION_TYPE_OPERATOR, del.Type) - operator, _, err := suite.ok.GetOperator(ctx, 1) + operator, err := suite.ok.GetOperator(ctx, 1) suite.Require().NoError(err) suite.Assert().Equal( sdk.NewDecCoins(sdk.NewInt64DecCoin("stake", 50)), diff --git a/x/restaking/keeper/grpc_query.go b/x/restaking/keeper/grpc_query.go index 7c63b28b..17f48c9f 100644 --- a/x/restaking/keeper/grpc_query.go +++ b/x/restaking/keeper/grpc_query.go @@ -38,15 +38,14 @@ func (k Querier) OperatorJoinedServices(ctx context.Context, req *types.QueryOpe return nil, status.Error(codes.InvalidArgument, "operator id cannot be 0") } - _, found, err := k.operatorsKeeper.GetOperator(ctx, req.OperatorId) + _, err := k.operatorsKeeper.GetOperator(ctx, req.OperatorId) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "operator not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.InvalidArgument, "operator not found") - } - // Get the operator joined services serviceIDs, pageResponse, err := query.CollectionPaginate(ctx, k.operatorJoinedServices, req.Pagination, func(key collections.Pair[uint32, uint32], _ collections.NoValue) (uint32, error) { @@ -124,15 +123,14 @@ func (k Querier) ServiceOperators(ctx context.Context, req *types.QueryServiceOp return nil, status.Error(codes.InvalidArgument, "service id cannot be 0") } - _, found, err := k.servicesKeeper.GetService(ctx, req.ServiceId) + _, err := k.servicesKeeper.GetService(ctx, req.ServiceId) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "service not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "service not found") - } - eligibleOperators, pageResponse, err := query.CollectionFilteredPaginate(ctx, k.operatorJoinedServices.Indexes.Service, req.Pagination, // Filter to return only the operators that have joined the service and // that are allowed to validate it @@ -146,15 +144,13 @@ func (k Querier) ServiceOperators(ctx context.Context, req *types.QueryServiceOp // Here is k2 the operator id since the Service index provides association // between a service and the operator securing it operatorID := key.K2() - operator, found, err := k.operatorsKeeper.GetOperator(ctx, operatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, operatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return operatorstypes.Operator{}, status.Errorf(codes.NotFound, "operator %d not found", operatorID) + } return operatorstypes.Operator{}, err } - - if !found { - return operatorstypes.Operator{}, errors.Wrapf( - operatorstypes.ErrOperatorNotFound, "operator %d not found", operatorID) - } return operator, nil }, query.WithCollectionPaginationPairPrefix[uint32, uint32](req.ServiceId)) if err != nil { @@ -834,15 +830,14 @@ func (k Querier) DelegatorPools(ctx context.Context, req *types.QueryDelegatorPo return err } - pool, found, err := k.poolsKeeper.GetPool(ctx, delegation.TargetID) + pool, err := k.poolsKeeper.GetPool(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return poolstypes.ErrPoolNotFound + } return err } - if !found { - return poolstypes.ErrPoolNotFound - } - pools = append(pools, pool) return nil @@ -879,15 +874,14 @@ func (k Querier) DelegatorPool(ctx context.Context, req *types.QueryDelegatorPoo return nil, status.Error(codes.NotFound, "pool delegation not found") } - pool, found, err := k.poolsKeeper.GetPool(ctx, delegation.TargetID) + pool, err := k.poolsKeeper.GetPool(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "pool not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "pool not found") - } - return &types.QueryDelegatorPoolResponse{ Pool: pool, }, nil @@ -915,15 +909,14 @@ func (k Querier) DelegatorOperators(ctx context.Context, req *types.QueryDelegat return err } - operator, found, err := k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) + operator, err := k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return operatorstypes.ErrOperatorNotFound + } return err } - if !found { - return operatorstypes.ErrOperatorNotFound - } - operators = append(operators, operator) return nil @@ -960,15 +953,14 @@ func (k Querier) DelegatorOperator(ctx context.Context, req *types.QueryDelegato return nil, status.Error(codes.NotFound, "operator delegation not found") } - operator, found, err := k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) + operator, err := k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "operator not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "operator not found") - } - return &types.QueryDelegatorOperatorResponse{ Operator: operator, }, nil @@ -996,15 +988,14 @@ func (k Querier) DelegatorServices(ctx context.Context, req *types.QueryDelegato return err } - pool, found, err := k.servicesKeeper.GetService(ctx, delegation.TargetID) + pool, err := k.servicesKeeper.GetService(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return servicestypes.ErrServiceNotFound + } return err } - if !found { - return servicestypes.ErrServiceNotFound - } - services = append(services, pool) return nil @@ -1041,15 +1032,14 @@ func (k Querier) DelegatorService(ctx context.Context, req *types.QueryDelegator return nil, status.Error(codes.NotFound, "service delegation not found") } - service, found, err := k.servicesKeeper.GetService(ctx, delegation.TargetID) + service, err := k.servicesKeeper.GetService(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "service not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "service not found") - } - return &types.QueryDelegatorServiceResponse{ Service: service, }, nil @@ -1089,45 +1079,42 @@ func (k Querier) Params(ctx context.Context, _ *types.QueryParamsRequest) (*type // PoolDelegationToPoolDelegationResponse converts a PoolDelegation to a PoolDelegationResponse func PoolDelegationToPoolDelegationResponse(ctx context.Context, k *Keeper, delegation types.Delegation) (types.DelegationResponse, error) { - pool, found, err := k.poolsKeeper.GetPool(ctx, delegation.TargetID) + pool, err := k.poolsKeeper.GetPool(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.DelegationResponse{}, poolstypes.ErrPoolNotFound + } return types.DelegationResponse{}, err } - if !found { - return types.DelegationResponse{}, poolstypes.ErrPoolNotFound - } - truncatedBalance, _ := pool.TokensFromShares(delegation.Shares).TruncateDecimal() return types.NewDelegationResponse(delegation, truncatedBalance), nil } // OperatorDelegationToOperatorDelegationResponse converts a OperatorDelegation to a OperatorDelegationResponse func OperatorDelegationToOperatorDelegationResponse(ctx context.Context, k *Keeper, delegation types.Delegation) (types.DelegationResponse, error) { - operator, found, err := k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) + operator, err := k.operatorsKeeper.GetOperator(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.DelegationResponse{}, operatorstypes.ErrOperatorNotFound + } return types.DelegationResponse{}, err } - if !found { - return types.DelegationResponse{}, operatorstypes.ErrOperatorNotFound - } - truncatedBalance, _ := operator.TokensFromShares(delegation.Shares).TruncateDecimal() return types.NewDelegationResponse(delegation, truncatedBalance), nil } // ServiceDelegationToServiceDelegationResponse converts a ServiceDelegation to a ServiceDelegationResponse func ServiceDelegationToServiceDelegationResponse(ctx context.Context, k *Keeper, delegation types.Delegation) (types.DelegationResponse, error) { - service, found, err := k.servicesKeeper.GetService(ctx, delegation.TargetID) + service, err := k.servicesKeeper.GetService(ctx, delegation.TargetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.DelegationResponse{}, servicestypes.ErrServiceNotFound + } return types.DelegationResponse{}, err } - if !found { - return types.DelegationResponse{}, servicestypes.ErrServiceNotFound - } - truncatedBalance, _ := service.TokensFromShares(delegation.Shares).TruncateDecimal() return types.NewDelegationResponse(delegation, truncatedBalance), nil } diff --git a/x/restaking/keeper/invariants.go b/x/restaking/keeper/invariants.go index 606a7e82..8637bbc2 100644 --- a/x/restaking/keeper/invariants.go +++ b/x/restaking/keeper/invariants.go @@ -1,8 +1,10 @@ package keeper import ( + "errors" "fmt" + "cosmossdk.io/collections" sdk "github.com/cosmos/cosmos-sdk/types" operatorstypes "github.com/milkyway-labs/milkyway/v3/x/operators/types" @@ -182,15 +184,14 @@ func PoolsDelegatorsSharesInvariant(k *Keeper) sdk.Invariant { } for poolID, delegatorsShares := range poolsDelegatorsShares { - pool, found, err := k.poolsKeeper.GetPool(ctx, poolID) + pool, err := k.poolsKeeper.GetPool(ctx, poolID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + panic(fmt.Errorf("pool with id %d not found", poolID)) + } panic(err) } - if !found { - panic(fmt.Errorf("pool with id %d not found", poolID)) - } - sharesAmount := delegatorsShares.AmountOf(pool.GetSharesDenom(pool.Denom)) if !pool.DelegatorShares.Equal(sharesAmount) { broken = true @@ -266,15 +267,14 @@ func OperatorsDelegatorsSharesInvariant(k *Keeper) sdk.Invariant { } for operatorID, delegatorsShares := range operatorsDelegatorsShares { - operator, found, err := k.operatorsKeeper.GetOperator(ctx, operatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, operatorID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + panic(fmt.Errorf("operator with id %d not found", operatorID)) + } panic(err) } - if !found { - panic(fmt.Errorf("operator with id %d not found", operatorID)) - } - if !operator.DelegatorShares.Equal(delegatorsShares) { broken = true msg += fmt.Sprintf("operator %d total shares: %v, delegators shares: %v\n", operatorID, operator.DelegatorShares, delegatorsShares) @@ -348,15 +348,14 @@ func ServicesDelegatorsSharesInvariant(k *Keeper) sdk.Invariant { } for serviceID, delegatorsShares := range servicesDelegatorsShares { - service, found, err := k.servicesKeeper.GetService(ctx, serviceID) + service, err := k.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + panic(fmt.Errorf("service with id %d not found", serviceID)) + } panic(err) } - if !found { - panic(fmt.Errorf("service with id %d not found", serviceID)) - } - if !service.DelegatorShares.Equal(delegatorsShares) { broken = true msg += fmt.Sprintf("service %d total shares: %v, delegators shares: %v\n", serviceID, service.DelegatorShares, delegatorsShares) @@ -374,15 +373,13 @@ func AllowedOperatorsExistInvariant(k *Keeper) sdk.Invariant { // Iterate over all the services joined by operators var notFoundOperatorsIDs []uint32 err := k.IterateAllServicesAllowedOperators(ctx, func(serviceID uint32, operatorID uint32) (stop bool, err error) { - _, found, err := k.operatorsKeeper.GetOperator(ctx, serviceID) + _, err = k.operatorsKeeper.GetOperator(ctx, serviceID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + notFoundOperatorsIDs = append(notFoundOperatorsIDs, operatorID) + } return true, err } - - if !found { - notFoundOperatorsIDs = append(notFoundOperatorsIDs, operatorID) - } - return false, nil }) if err != nil { @@ -405,15 +402,13 @@ func OperatorsJoinedServicesExistInvariant(k *Keeper) sdk.Invariant { // Iterate over all the operators joined services var notFoundServicesIDs []uint32 err := k.IterateAllOperatorsJoinedServices(ctx, func(operatorID uint32, serviceID uint32) (stop bool, err error) { - _, found, err := k.servicesKeeper.GetService(ctx, serviceID) + _, err = k.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + notFoundServicesIDs = append(notFoundServicesIDs, serviceID) + } return false, err } - - if !found { - notFoundServicesIDs = append(notFoundServicesIDs, serviceID) - } - return false, nil }) if err != nil { diff --git a/x/restaking/keeper/msg_server.go b/x/restaking/keeper/msg_server.go index bfab594a..38d74d03 100644 --- a/x/restaking/keeper/msg_server.go +++ b/x/restaking/keeper/msg_server.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "cosmossdk.io/collections" "cosmossdk.io/errors" "github.com/cosmos/cosmos-sdk/telemetry" sdk "github.com/cosmos/cosmos-sdk/types" @@ -30,28 +31,26 @@ func NewMsgServer(keeper *Keeper) types.MsgServer { // JoinService defines the rpc method for Msg/JoinService func (k msgServer) JoinService(ctx context.Context, msg *types.MsgJoinService) (*types.MsgJoinServiceResponse, error) { - operator, found, err := k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, operatorstypes.ErrOperatorNotFound - } - if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can join the service") } - service, found, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, errors.Wrapf(sdkerrors.ErrNotFound, "service %d not found", msg.ServiceID) + } return nil, err } - if !found { - return nil, errors.Wrapf(sdkerrors.ErrNotFound, "service %d not found", msg.ServiceID) - } - if !service.IsActive() { return nil, errors.Wrapf(servicestypes.ErrServiceNotActive, "service %d is not active", msg.ServiceID) } @@ -75,28 +74,26 @@ func (k msgServer) JoinService(ctx context.Context, msg *types.MsgJoinService) ( // LeaveService defines the rpc method for Msg/LeaveService func (k msgServer) LeaveService(ctx context.Context, msg *types.MsgLeaveService) (*types.MsgLeaveServiceResponse, error) { - operator, found, err := k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, operatorstypes.ErrOperatorNotFound - } - if operator.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can leave the service") } - _, found, err = k.servicesKeeper.GetService(ctx, msg.ServiceID) + _, err = k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, errors.Wrapf(sdkerrors.ErrNotFound, "service %d not found", msg.ServiceID) + } return nil, err } - if !found { - return nil, errors.Wrapf(sdkerrors.ErrNotFound, "service %d not found", msg.ServiceID) - } - err = k.RemoveServiceFromOperatorJoinedServices(ctx, msg.OperatorID, msg.ServiceID) if err != nil { return nil, err @@ -117,25 +114,23 @@ func (k msgServer) LeaveService(ctx context.Context, msg *types.MsgLeaveService) // AddOperatorToAllowList defines the rpc method for Msg/AddOperatorToAllowList func (k msgServer) AddOperatorToAllowList(ctx context.Context, msg *types.MsgAddOperatorToAllowList) (*types.MsgAddOperatorToAllowListResponse, error) { // Ensure that the service exists - service, found, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } - // Ensure that the operator exists - _, found, err = k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) + _, err = k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, operatorstypes.ErrOperatorNotFound - } - // Ensure the service admin is performing this action if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the service admin can allow an operator") @@ -171,15 +166,14 @@ func (k msgServer) AddOperatorToAllowList(ctx context.Context, msg *types.MsgAdd // RemoveOperatorFromAllowlist defines the rpc method for Msg/RemoveOperatorFromAllowlist func (k msgServer) RemoveOperatorFromAllowlist(ctx context.Context, msg *types.MsgRemoveOperatorFromAllowlist) (*types.MsgRemoveOperatorFromAllowlistResponse, error) { // Ensure that the service exists - service, found, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } - // Ensure the service admin is performing this action if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the service admin can allow an operator") @@ -215,30 +209,28 @@ func (k msgServer) RemoveOperatorFromAllowlist(ctx context.Context, msg *types.M // BorrowPoolSecurity defines the rpc method for Msg/BorrowPoolSecurity func (k msgServer) BorrowPoolSecurity(ctx context.Context, msg *types.MsgBorrowPoolSecurity) (*types.MsgBorrowPoolSecurityResponse, error) { // Ensure that the service exists - service, found, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } - // Ensure the service is active if !service.IsActive() { return nil, errors.Wrapf(servicestypes.ErrServiceNotActive, "service %d is not active", msg.ServiceID) } // Ensure that the pool exists - _, found, err = k.poolsKeeper.GetPool(ctx, msg.PoolID) + _, err = k.poolsKeeper.GetPool(ctx, msg.PoolID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, poolstypes.ErrPoolNotFound + } return nil, err } - if !found { - return nil, poolstypes.ErrPoolNotFound - } - // Ensure the service admin is performing this action if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, @@ -276,15 +268,14 @@ func (k msgServer) BorrowPoolSecurity(ctx context.Context, msg *types.MsgBorrowP // CeasePoolSecurityBorrow defines the rpc method for Msg/CeasePoolSecurityBorrow func (k msgServer) CeasePoolSecurityBorrow(ctx context.Context, msg *types.MsgCeasePoolSecurityBorrow) (*types.MsgCeasePoolSecurityBorrowResponse, error) { // Ensure that the service exists - service, found, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } - // Ensure the service admin is performing this action if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, diff --git a/x/restaking/keeper/operator_restaking.go b/x/restaking/keeper/operator_restaking.go index 514cbe2b..bd936bd8 100644 --- a/x/restaking/keeper/operator_restaking.go +++ b/x/restaking/keeper/operator_restaking.go @@ -81,15 +81,14 @@ func (k *Keeper) RemoveOperatorDelegation(ctx context.Context, delegation types. // DelegateToOperator sends the given amount to the operator account and saves the delegation for the given user func (k *Keeper) DelegateToOperator(ctx context.Context, operatorID uint32, amount sdk.Coins, delegator string) (sdk.DecCoins, error) { // Get the operator - operator, found, err := k.operatorsKeeper.GetOperator(ctx, operatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, operatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return sdk.NewDecCoins(), operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return sdk.NewDecCoins(), operatorstypes.ErrOperatorNotFound - } - restakableDenoms, err := k.GetRestakableDenoms(ctx) if err != nil { return nil, err @@ -164,15 +163,14 @@ func (k *Keeper) GetOperatorUnbondingDelegation(ctx context.Context, operatorID // unbonding delegation for the given user func (k *Keeper) UndelegateFromOperator(ctx context.Context, operatorID uint32, amount sdk.Coins, delegator string) (time.Time, error) { // Find the operator - operator, found, err := k.operatorsKeeper.GetOperator(ctx, operatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, operatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return time.Time{}, operatorstypes.ErrOperatorNotFound + } return time.Time{}, err } - if !found { - return time.Time{}, operatorstypes.ErrOperatorNotFound - } - // Get the shares shares, err := k.ValidateUnbondAmount(ctx, delegator, operator, amount) if err != nil { diff --git a/x/restaking/keeper/operator_restaking_test.go b/x/restaking/keeper/operator_restaking_test.go index b2f55c78..9e684145 100644 --- a/x/restaking/keeper/operator_restaking_test.go +++ b/x/restaking/keeper/operator_restaking_test.go @@ -367,9 +367,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToOperator() { ), check: func(ctx sdk.Context) { // Make sure the operator now exists - operator, found, err := suite.ok.GetOperator(ctx, 1) + operator, err := suite.ok.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(operatorstypes.Operator{ ID: 1, Status: operatorstypes.OPERATOR_STATUS_ACTIVE, @@ -455,9 +454,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToOperator() { ), check: func(ctx sdk.Context) { // Make sure the operator now exists - operator, found, err := suite.ok.GetOperator(ctx, 1) + operator, err := suite.ok.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(operatorstypes.Operator{ ID: 1, Status: operatorstypes.OPERATOR_STATUS_ACTIVE, @@ -562,9 +560,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToOperator() { ), check: func(ctx sdk.Context) { // Make sure the operator now exists - operator, found, err := suite.ok.GetOperator(ctx, 1) + operator, err := suite.ok.GetOperator(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(operatorstypes.Operator{ ID: 1, Status: operatorstypes.OPERATOR_STATUS_ACTIVE, diff --git a/x/restaking/keeper/operators_hooks.go b/x/restaking/keeper/operators_hooks.go index 0c3bd911..809f68b5 100644 --- a/x/restaking/keeper/operators_hooks.go +++ b/x/restaking/keeper/operators_hooks.go @@ -2,11 +2,13 @@ package keeper import ( "context" - "fmt" + "errors" "cosmossdk.io/collections" + errorsmod "cosmossdk.io/errors" operatorstypes "github.com/milkyway-labs/milkyway/v3/x/operators/types" + servicestypes "github.com/milkyway-labs/milkyway/v3/x/services/types" ) var _ operatorstypes.OperatorsHooks = &OperatorsHooks{} @@ -85,15 +87,14 @@ func (o *OperatorsHooks) removeOperatorFromServicesAllowList(ctx context.Context return err } if !isConfigured { - service, found, err := o.servicesKeeper.GetService(ctx, serviceID) + service, err := o.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return errorsmod.Wrapf(servicestypes.ErrServiceNotFound, "service %d not found", serviceID) + } return err } - if !found { - return fmt.Errorf("service %d not found", serviceID) - } - if !service.IsActive() { // The service is not active, nothing to do continue diff --git a/x/restaking/keeper/operators_hooks_test.go b/x/restaking/keeper/operators_hooks_test.go index 736473bc..ed0b9c73 100644 --- a/x/restaking/keeper/operators_hooks_test.go +++ b/x/restaking/keeper/operators_hooks_test.go @@ -82,9 +82,8 @@ func (suite *KeeperTestSuite) TestOperatorHooks_BeforeOperatorDeleted() { suite.Assert().False(joined) // Ensure that the service is status has not changed - service, found, err := suite.sk.GetService(ctx, 2) + service, err := suite.sk.GetService(ctx, 2) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.SERVICE_STATUS_ACTIVE, service.Status) }, operatorID: 1, @@ -131,9 +130,8 @@ func (suite *KeeperTestSuite) TestOperatorHooks_BeforeOperatorDeleted() { suite.Assert().True(joined) // Ensure that the service is status has not changed - service, found, err := suite.sk.GetService(ctx, 2) + service, err := suite.sk.GetService(ctx, 2) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.SERVICE_STATUS_ACTIVE, service.Status) }, operatorID: 1, @@ -164,9 +162,8 @@ func (suite *KeeperTestSuite) TestOperatorHooks_BeforeOperatorDeleted() { suite.Assert().False(joined) // Ensure that the service is now inactive - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.SERVICE_STATUS_INACTIVE, service.Status) }, operatorID: 1, @@ -200,9 +197,8 @@ func (suite *KeeperTestSuite) TestOperatorHooks_BeforeOperatorDeleted() { suite.Assert().False(joined) // Ensure the service status has not changed - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.SERVICE_STATUS_ACTIVE, service.Status) }, operatorID: 1, @@ -233,9 +229,8 @@ func (suite *KeeperTestSuite) TestOperatorHooks_BeforeOperatorDeleted() { suite.Assert().False(joined) // Ensure the service status has not changed - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.SERVICE_STATUS_CREATED, service.Status) }, operatorID: 1, @@ -266,9 +261,8 @@ func (suite *KeeperTestSuite) TestOperatorHooks_BeforeOperatorDeleted() { suite.Assert().False(joined) // Ensure the service status has not changed - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.SERVICE_STATUS_INACTIVE, service.Status) }, operatorID: 1, diff --git a/x/restaking/keeper/pool_restaking_test.go b/x/restaking/keeper/pool_restaking_test.go index b7f525c4..2d89730b 100644 --- a/x/restaking/keeper/pool_restaking_test.go +++ b/x/restaking/keeper/pool_restaking_test.go @@ -308,9 +308,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToPool() { expShares: sdk.NewDecCoins(sdk.NewDecCoinFromDec("pool/1/umilk", sdkmath.LegacyNewDec(100))), check: func(ctx sdk.Context) { // Make sure the pool now exists - pool, found, err := suite.pk.GetPool(ctx, 1) + pool, err := suite.pk.GetPool(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(poolstypes.Pool{ ID: 1, Denom: "umilk", @@ -374,9 +373,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToPool() { expShares: sdk.NewDecCoins(sdk.NewDecCoinFromDec("pool/1/umilk", sdkmath.LegacyNewDec(500))), check: func(ctx sdk.Context) { // Make sure the pool now exists - pool, found, err := suite.pk.GetPool(ctx, 1) + pool, err := suite.pk.GetPool(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(poolstypes.Pool{ ID: 1, Denom: "umilk", @@ -448,9 +446,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToPool() { expShares: sdk.NewDecCoins(sdk.NewDecCoinFromDec("pool/1/umilk", sdkmath.LegacyNewDecWithPrec(15625, 2))), check: func(ctx sdk.Context) { // Make sure the pool now exists - pool, found, err := suite.pk.GetPool(ctx, 1) + pool, err := suite.pk.GetPool(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(poolstypes.Pool{ ID: 1, Denom: "umilk", diff --git a/x/restaking/keeper/service_restaking.go b/x/restaking/keeper/service_restaking.go index 0ad61821..d9226c50 100644 --- a/x/restaking/keeper/service_restaking.go +++ b/x/restaking/keeper/service_restaking.go @@ -27,8 +27,8 @@ func (k *Keeper) GetAllServiceAllowedOperators(ctx context.Context, serviceID ui if err != nil { return nil, err } - defer iterator.Close() + var operators []uint32 for ; iterator.Valid(); iterator.Next() { serviceOperatorPair, err := iterator.Key() @@ -64,12 +64,7 @@ func (k *Keeper) IsServiceOperatorsAllowListConfigured(ctx context.Context, serv return false, err } defer iterator.Close() - - for ; iterator.Valid(); iterator.Next() { - return true, nil - } - - return false, nil + return iterator.Valid(), nil } // IsOperatorInServiceAllowList returns true if the given operator is in the @@ -145,12 +140,7 @@ func (k *Keeper) IsServiceSecuringPoolsConfigured(ctx context.Context, serviceID return false, err } defer iterator.Close() - - for ; iterator.Valid(); iterator.Next() { - return true, nil - } - - return false, nil + return iterator.Valid(), nil } // IsPoolInServiceSecuringPools returns true if the pool is in the list @@ -224,15 +214,14 @@ func (k *Keeper) RemoveServiceDelegation(ctx context.Context, delegation types.D // DelegateToService sends the given amount to the service account and saves the delegation for the given user func (k *Keeper) DelegateToService(ctx context.Context, serviceID uint32, amount sdk.Coins, delegator string) (sdk.DecCoins, error) { // Get the service - service, found, err := k.servicesKeeper.GetService(ctx, serviceID) + service, err := k.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return sdk.NewDecCoins(), servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return sdk.NewDecCoins(), servicestypes.ErrServiceNotFound - } - restakableDenoms, err := k.GetRestakableDenoms(ctx) if err != nil { return nil, err @@ -327,15 +316,14 @@ func (k *Keeper) GetServiceUnbondingDelegation(ctx context.Context, serviceID ui // unbonding delegation for the given user func (k *Keeper) UndelegateFromService(ctx context.Context, serviceID uint32, amount sdk.Coins, delegator string) (time.Time, error) { // Find the service - service, found, err := k.servicesKeeper.GetService(ctx, serviceID) + service, err := k.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return time.Time{}, servicestypes.ErrServiceNotFound + } return time.Time{}, err } - if !found { - return time.Time{}, servicestypes.ErrServiceNotFound - } - // Get the shares shares, err := k.ValidateUnbondAmount(ctx, delegator, service, amount) if err != nil { diff --git a/x/restaking/keeper/service_restaking_test.go b/x/restaking/keeper/service_restaking_test.go index 7fed170e..eae5eabd 100644 --- a/x/restaking/keeper/service_restaking_test.go +++ b/x/restaking/keeper/service_restaking_test.go @@ -717,9 +717,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToService() { expShares: sdk.NewDecCoins(sdk.NewDecCoinFromDec("service/1/umilk", sdkmath.LegacyNewDec(500))), check: func(ctx sdk.Context) { // Make sure the service now exists - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.Service{ ID: 1, Status: servicestypes.SERVICE_STATUS_ACTIVE, @@ -804,9 +803,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToService() { expShares: sdk.NewDecCoins(sdk.NewDecCoinFromDec("service/1/uinit", sdkmath.LegacyNewDec(100))), check: func(ctx sdk.Context) { // Make sure the service now exists - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.Service{ ID: 1, Status: servicestypes.SERVICE_STATUS_ACTIVE, @@ -912,9 +910,8 @@ func (suite *KeeperTestSuite) TestKeeper_DelegateToService() { ), check: func(ctx sdk.Context) { // Make sure the service now exists - service, found, err := suite.sk.GetService(ctx, 1) + service, err := suite.sk.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(servicestypes.Service{ ID: 1, Status: servicestypes.SERVICE_STATUS_ACTIVE, diff --git a/x/restaking/keeper/services_hooks.go b/x/restaking/keeper/services_hooks.go index 54485ae7..d095ecf4 100644 --- a/x/restaking/keeper/services_hooks.go +++ b/x/restaking/keeper/services_hooks.go @@ -46,7 +46,7 @@ func (h *ServicesHooks) BeforeServiceDeleted(ctx context.Context, serviceID uint // Get the iterator to iterate over the operators that are // allowed to secure this service - serviceOperatorsAllowListIter, err := h.serviceOperatorsAllowList.Iterate(ctx, collections.NewPrefixedPairRange[uint32, uint32](serviceID)) + serviceOperatorsAllowListIter, err := h.ServiceAllowedOperatorsIterator(ctx, serviceID) if err != nil { return err } @@ -66,7 +66,7 @@ func (h *ServicesHooks) BeforeServiceDeleted(ctx context.Context, serviceID uint // Get the iterator to iterate over the list of pools from // which the service is allowed to borrow security - serviceSecuringPoolsIter, err := h.serviceSecuringPools.Iterate(ctx, collections.NewPrefixedPairRange[uint32, uint32](serviceID)) + serviceSecuringPoolsIter, err := h.ServiceSecuringPoolsIterator(ctx, serviceID) if err != nil { return err } diff --git a/x/restaking/types/expected_keepers.go b/x/restaking/types/expected_keepers.go index 0d196961..554ca175 100644 --- a/x/restaking/types/expected_keepers.go +++ b/x/restaking/types/expected_keepers.go @@ -24,14 +24,14 @@ type BankKeeper interface { type PoolsKeeper interface { GetPoolByDenom(ctx context.Context, denom string) (poolstypes.Pool, bool, error) CreateOrGetPoolByDenom(ctx context.Context, denom string) (poolstypes.Pool, error) - GetPool(ctx context.Context, poolID uint32) (poolstypes.Pool, bool, error) + GetPool(ctx context.Context, poolID uint32) (poolstypes.Pool, error) SavePool(ctx context.Context, pool poolstypes.Pool) error IteratePools(ctx context.Context, cb func(poolstypes.Pool) (bool, error)) error GetPools(ctx context.Context) ([]poolstypes.Pool, error) } type OperatorsKeeper interface { - GetOperator(ctx context.Context, operatorID uint32) (operatorstypes.Operator, bool, error) + GetOperator(ctx context.Context, operatorID uint32) (operatorstypes.Operator, error) SaveOperator(ctx context.Context, operator operatorstypes.Operator) error IterateOperators(ctx context.Context, cb func(operatorstypes.Operator) (bool, error)) error GetOperators(ctx context.Context) ([]operatorstypes.Operator, error) @@ -41,7 +41,7 @@ type OperatorsKeeper interface { type ServicesKeeper interface { HasService(ctx context.Context, serviceID uint32) (bool, error) - GetService(ctx context.Context, serviceID uint32) (servicestypes.Service, bool, error) + GetService(ctx context.Context, serviceID uint32) (servicestypes.Service, error) SaveService(ctx context.Context, service servicestypes.Service) error IterateServices(ctx context.Context, cb func(servicestypes.Service) (bool, error)) error GetServices(ctx context.Context) ([]servicestypes.Service, error) diff --git a/x/rewards/keeper/allocation.go b/x/rewards/keeper/allocation.go index 09330e43..b98ba5b4 100644 --- a/x/rewards/keeper/allocation.go +++ b/x/rewards/keeper/allocation.go @@ -197,15 +197,14 @@ func (k *Keeper) AllocateRewardsByPlan( return err } - service, found, err := k.servicesKeeper.GetService(ctx, plan.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, plan.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return servicestypes.ErrServiceNotFound + } return err } - if !found { - return servicestypes.ErrServiceNotFound - } - // Ensure that we are distribution rewards only for active services if !service.IsActive() { return nil diff --git a/x/rewards/keeper/allocation_test.go b/x/rewards/keeper/allocation_test.go index fe0d4845..b4839492 100644 --- a/x/rewards/keeper/allocation_test.go +++ b/x/rewards/keeper/allocation_test.go @@ -970,7 +970,7 @@ func (suite *KeeperTestSuite) TestAllocateRewards_InactiveOperator() { suite.Require().NoError(err) // Refresh the updated state of operator 2. - operator2, _, err = suite.operatorsKeeper.GetOperator(ctx, operator2.ID) + operator2, err = suite.operatorsKeeper.GetOperator(ctx, operator2.ID) suite.Require().NoError(err) // Operator 2 becomes inactive. err = suite.operatorsKeeper.StartOperatorInactivation(ctx, operator2) diff --git a/x/rewards/keeper/common_test.go b/x/rewards/keeper/common_test.go index d5a96309..11d92dbe 100644 --- a/x/rewards/keeper/common_test.go +++ b/x/rewards/keeper/common_test.go @@ -145,9 +145,8 @@ func (suite *KeeperTestSuite) CreateService(ctx sdk.Context, name string, admin _, err = servicesMsgServer.ActivateService(ctx, servicestypes.NewMsgActivateService(resp.NewServiceID, admin)) suite.Require().NoError(err) - service, found, err := suite.servicesKeeper.GetService(ctx, resp.NewServiceID) + service, err := suite.servicesKeeper.GetService(ctx, resp.NewServiceID) suite.Require().NoError(err) - suite.Require().True(found, "service must be found") return service } @@ -167,9 +166,8 @@ func (suite *KeeperTestSuite) CreateOperator(ctx sdk.Context, name string, admin suite.Require().NoError(err) // Make sure the operator is found - operator, found, err := suite.operatorsKeeper.GetOperator(ctx, resp.NewOperatorID) + operator, err := suite.operatorsKeeper.GetOperator(ctx, resp.NewOperatorID) suite.Require().NoError(err) - suite.Require().True(found, "operator must be found") return operator } @@ -182,9 +180,8 @@ func (suite *KeeperTestSuite) UpdateOperatorParams( joinedServicesIDs []uint32, ) { // Make sure the operator is found - _, found, err := suite.operatorsKeeper.GetOperator(ctx, operatorID) + _, err := suite.operatorsKeeper.GetOperator(ctx, operatorID) suite.Require().NoError(err) - suite.Require().True(found, "operator must be found") // Sets the operator commission rate err = suite.operatorsKeeper.SaveOperatorParams(ctx, operatorID, operatorstypes.NewOperatorParams(commissionRate)) @@ -221,9 +218,8 @@ func (suite *KeeperTestSuite) AddPoolsToServiceSecuringPools( whitelistedPoolsIDs []uint32, ) { // Make sure the service is found - _, found, err := suite.servicesKeeper.GetService(ctx, serviceID) + _, err := suite.servicesKeeper.GetService(ctx, serviceID) suite.Require().NoError(err) - suite.Require().True(found, "service must be found") for _, poolID := range whitelistedPoolsIDs { err := suite.restakingKeeper.AddPoolToServiceSecuringPools(ctx, serviceID, poolID) @@ -239,9 +235,8 @@ func (suite *KeeperTestSuite) AddOperatorsToServiceAllowList( allowedOperatorsID []uint32, ) { // Make sure the service is found - _, found, err := suite.servicesKeeper.GetService(ctx, serviceID) + _, err := suite.servicesKeeper.GetService(ctx, serviceID) suite.Require().NoError(err) - suite.Require().True(found, "service must be found") for _, operatorID := range allowedOperatorsID { err := suite.restakingKeeper.AddOperatorToServiceAllowList(ctx, serviceID, operatorID) @@ -263,9 +258,8 @@ func (suite *KeeperTestSuite) CreateRewardsPlan( usersDistr rewardstypes.UsersDistribution, initialRewards sdk.Coins, ) rewardstypes.RewardsPlan { - service, found, err := suite.servicesKeeper.GetService(ctx, serviceID) + service, err := suite.servicesKeeper.GetService(ctx, serviceID) suite.Require().NoError(err) - suite.Require().True(found, "service must be found") rewardsMsgServer := keeper.NewMsgServer(suite.keeper) resp, err := rewardsMsgServer.CreateRewardsPlan(ctx, rewardstypes.NewMsgCreateRewardsPlan( diff --git a/x/rewards/keeper/hooks.go b/x/rewards/keeper/hooks.go index ff1e17c7..09553f39 100644 --- a/x/rewards/keeper/hooks.go +++ b/x/rewards/keeper/hooks.go @@ -2,7 +2,9 @@ package keeper import ( "context" + "errors" + "cosmossdk.io/collections" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" restakingtypes "github.com/milkyway-labs/milkyway/v3/x/restaking/types" @@ -133,15 +135,14 @@ func (k *Keeper) AfterDelegationModified(ctx context.Context, delType restakingt // AfterServiceAccreditationModified implements servicestypes.ServicesHooks func (k *Keeper) AfterServiceAccreditationModified(ctx context.Context, serviceID uint32) error { - service, found, err := k.servicesKeeper.GetService(ctx, serviceID) + service, err := k.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return servicestypes.ErrServiceNotFound + } return err } - if !found { - return servicestypes.ErrServiceNotFound - } - err = k.restakingKeeper.IterateServiceDelegations(ctx, serviceID, func(del restakingtypes.Delegation) (stop bool, err error) { preferences, err := k.restakingKeeper.GetUserPreferences(ctx, del.UserAddress) if err != nil { diff --git a/x/rewards/keeper/keeper.go b/x/rewards/keeper/keeper.go index 2a15326f..e8665846 100644 --- a/x/rewards/keeper/keeper.go +++ b/x/rewards/keeper/keeper.go @@ -16,8 +16,7 @@ import ( ) type Keeper struct { - cdc codec.Codec - storeService corestoretypes.KVStoreService + cdc codec.Codec accountKeeper types.AccountKeeper bankKeeper types.BankKeeper @@ -74,7 +73,6 @@ func NewKeeper( sb := collections.NewSchemaBuilder(storeService) k := &Keeper{ cdc: cdc, - storeService: storeService, accountKeeper: accountKeeper, bankKeeper: bankKeeper, communityPoolKeeper: communityPoolKeeper, diff --git a/x/rewards/keeper/msg_server.go b/x/rewards/keeper/msg_server.go index 053359ee..43aef178 100644 --- a/x/rewards/keeper/msg_server.go +++ b/x/rewards/keeper/msg_server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "cosmossdk.io/collections" "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -27,15 +28,14 @@ func NewMsgServer(k *Keeper) types.MsgServer { // CreateRewardsPlan defines the rpc method for Msg/CreateRewardsPlan func (k msgServer) CreateRewardsPlan(ctx context.Context, msg *types.MsgCreateRewardsPlan) (*types.MsgCreateRewardsPlanResponse, error) { // Make sure the creator is the admin of the service - service, found, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } - if msg.Sender != service.Admin { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only service admin can create rewards plan") } @@ -103,15 +103,14 @@ func (k msgServer) EditRewardsPlan(ctx context.Context, msg *types.MsgEditReward } // Get the service to which the rewards is associated - service, found, err := k.servicesKeeper.GetService(ctx, rewardsPlan.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, rewardsPlan.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, servicestypes.ErrServiceNotFound + } return nil, err } - if !found { - return nil, servicestypes.ErrServiceNotFound - } - // Make sure the editor is the admin of the service if msg.Sender != service.Admin { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only service admin can create rewards plan") @@ -212,15 +211,14 @@ func (k msgServer) WithdrawOperatorCommission(ctx context.Context, msg *types.Ms return nil, sdkerrors.ErrInvalidAddress.Wrapf("invalid sender address: %s", err) } - operator, found, err := k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, msg.OperatorID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, operatorstypes.ErrOperatorNotFound - } - if msg.Sender != operator.Admin { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only operator admin can withdraw operator commission") } diff --git a/x/rewards/keeper/rewards_plan.go b/x/rewards/keeper/rewards_plan.go index 3eda7d43..e86306a4 100644 --- a/x/rewards/keeper/rewards_plan.go +++ b/x/rewards/keeper/rewards_plan.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "cosmossdk.io/collections" "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -24,15 +25,14 @@ func (k *Keeper) CreateRewardsPlan( operatorsDistribution types.Distribution, usersDistribution types.UsersDistribution, ) (types.RewardsPlan, error) { - _, found, err := k.servicesKeeper.GetService(ctx, serviceID) + _, err := k.servicesKeeper.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.RewardsPlan{}, servicestypes.ErrServiceNotFound + } return types.RewardsPlan{}, err } - if !found { - return types.RewardsPlan{}, servicestypes.ErrServiceNotFound - } - // Get the plan id to be used planID, err := k.NextRewardsPlanID.Get(ctx) if err != nil { @@ -130,15 +130,14 @@ func (k *Keeper) terminateRewardsPlan(ctx context.Context, plan types.RewardsPla remaining := k.bankKeeper.GetAllBalances(ctx, rewardsPoolAddr) if remaining.IsAllPositive() { // Get the service's address. - service, found, err := k.servicesKeeper.GetService(ctx, plan.ServiceID) + service, err := k.servicesKeeper.GetService(ctx, plan.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return servicestypes.ErrServiceNotFound + } return err } - if !found { - return servicestypes.ErrServiceNotFound - } - serviceAddr, err := k.accountKeeper.AddressCodec().StringToBytes(service.Address) if err != nil { return err diff --git a/x/rewards/keeper/target.go b/x/rewards/keeper/target.go index 1737ee03..24045f14 100644 --- a/x/rewards/keeper/target.go +++ b/x/rewards/keeper/target.go @@ -35,13 +35,13 @@ func (k *Keeper) GetDelegationTarget( ) (DelegationTarget, error) { switch delType { case restakingtypes.DELEGATION_TYPE_POOL: - pool, found, err := k.poolsKeeper.GetPool(ctx, targetID) + pool, err := k.poolsKeeper.GetPool(ctx, targetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return DelegationTarget{}, poolstypes.ErrPoolNotFound + } return DelegationTarget{}, err } - if !found { - return DelegationTarget{}, poolstypes.ErrPoolNotFound - } return DelegationTarget{ DelegationTarget: pool, DelegationType: delType, @@ -51,13 +51,13 @@ func (k *Keeper) GetDelegationTarget( OutstandingRewards: k.PoolOutstandingRewards, }, nil case restakingtypes.DELEGATION_TYPE_OPERATOR: - operator, found, err := k.operatorsKeeper.GetOperator(ctx, targetID) + operator, err := k.operatorsKeeper.GetOperator(ctx, targetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return DelegationTarget{}, operatorstypes.ErrOperatorNotFound + } return DelegationTarget{}, err } - if !found { - return DelegationTarget{}, operatorstypes.ErrOperatorNotFound - } return DelegationTarget{ DelegationTarget: operator, DelegationType: delType, @@ -67,13 +67,13 @@ func (k *Keeper) GetDelegationTarget( OutstandingRewards: k.OperatorOutstandingRewards, }, nil case restakingtypes.DELEGATION_TYPE_SERVICE: - service, found, err := k.servicesKeeper.GetService(ctx, targetID) + service, err := k.servicesKeeper.GetService(ctx, targetID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return DelegationTarget{}, servicestypes.ErrServiceNotFound + } return DelegationTarget{}, err } - if !found { - return DelegationTarget{}, servicestypes.ErrServiceNotFound - } return DelegationTarget{ DelegationTarget: service, DelegationType: delType, diff --git a/x/rewards/keeper/withdraw.go b/x/rewards/keeper/withdraw.go index 2aafa757..54cb7864 100644 --- a/x/rewards/keeper/withdraw.go +++ b/x/rewards/keeper/withdraw.go @@ -2,8 +2,10 @@ package keeper import ( "context" + "errors" "fmt" + "cosmossdk.io/collections" errorsmod "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -63,15 +65,14 @@ func (k *Keeper) WithdrawDelegationRewards( // WithdrawOperatorCommission withdraws the operator's accumulated commission func (k *Keeper) WithdrawOperatorCommission(ctx context.Context, operatorID uint32) (types.Pools, error) { - operator, found, err := k.operatorsKeeper.GetOperator(ctx, operatorID) + operator, err := k.operatorsKeeper.GetOperator(ctx, operatorID) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return nil, operatorstypes.ErrOperatorNotFound + } return nil, err } - if !found { - return nil, operatorstypes.ErrOperatorNotFound - } - // Fetch the operator accumulated commission accumCommission, err := k.OperatorAccumulatedCommissions.Get(ctx, operatorID) if err != nil { diff --git a/x/rewards/types/expected_keepers.go b/x/rewards/types/expected_keepers.go index 1a394403..88d41cd2 100644 --- a/x/rewards/types/expected_keepers.go +++ b/x/rewards/types/expected_keepers.go @@ -43,20 +43,20 @@ type OracleKeeper interface { } type PoolsKeeper interface { - GetPool(ctx context.Context, poolID uint32) (poolstypes.Pool, bool, error) + GetPool(ctx context.Context, poolID uint32) (poolstypes.Pool, error) GetPools(ctx context.Context) ([]poolstypes.Pool, error) IteratePools(ctx context.Context, cb func(pool poolstypes.Pool) (stop bool, err error)) error } type OperatorsKeeper interface { - GetOperator(ctx context.Context, operatorID uint32) (operatorstypes.Operator, bool, error) + GetOperator(ctx context.Context, operatorID uint32) (operatorstypes.Operator, error) GetOperators(ctx context.Context) ([]operatorstypes.Operator, error) IterateOperators(ctx context.Context, cb func(operator operatorstypes.Operator) (stop bool, err error)) error GetOperatorParams(ctx context.Context, operatorID uint32) (operatorstypes.OperatorParams, error) } type ServicesKeeper interface { - GetService(ctx context.Context, serviceID uint32) (servicestypes.Service, bool, error) + GetService(ctx context.Context, serviceID uint32) (servicestypes.Service, error) GetServiceParams(ctx context.Context, serviceID uint32) (servicestypes.ServiceParams, error) IterateServices(ctx context.Context, cb func(service servicestypes.Service) (stop bool, err error)) error } diff --git a/x/services/keeper/alias_functions.go b/x/services/keeper/alias_functions.go index ce29f37e..45375a1c 100644 --- a/x/services/keeper/alias_functions.go +++ b/x/services/keeper/alias_functions.go @@ -19,29 +19,10 @@ func (k *Keeper) createAccountIfNotExists(ctx context.Context, address sdk.AccAd // IterateServices iterates over the services in the store and performs a callback function func (k *Keeper) IterateServices(ctx context.Context, cb func(service types.Service) (stop bool, err error)) error { - iterator, err := k.services.Iterate(ctx, nil) - if err != nil { - return err - } - defer iterator.Close() - - for ; iterator.Valid(); iterator.Next() { - service, err := iterator.Value() - if err != nil { - return err - } - - stop, err := cb(service) - if err != nil { - return err - } - - if stop { - break - } - } - - return nil + err := k.services.Walk(ctx, nil, func(_ uint32, service types.Service) (stop bool, err error) { + return cb(service) + }) + return err } // GetServices returns the services stored in the KVStore diff --git a/x/services/keeper/grpc_query.go b/x/services/keeper/grpc_query.go index 85998faf..2c0ba2a6 100644 --- a/x/services/keeper/grpc_query.go +++ b/x/services/keeper/grpc_query.go @@ -2,7 +2,9 @@ package keeper import ( "context" + "errors" + "cosmossdk.io/collections" "github.com/cosmos/cosmos-sdk/types/query" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -18,15 +20,14 @@ func (k *Keeper) Service(ctx context.Context, request *types.QueryServiceRequest return nil, status.Error(codes.InvalidArgument, "invalid service ID") } - service, found, err := k.GetService(ctx, request.ServiceId) + service, err := k.GetService(ctx, request.ServiceId) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "service not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "service not found") - } - return &types.QueryServiceResponse{Service: service}, nil } @@ -54,15 +55,14 @@ func (k *Keeper) ServiceParams(ctx context.Context, request *types.QueryServiceP } // Ensure the service exists - _, found, err := k.GetService(ctx, request.ServiceId) + _, err := k.GetService(ctx, request.ServiceId) if err != nil { + if errors.Is(err, collections.ErrNotFound) { + return nil, status.Error(codes.NotFound, "service not found") + } return nil, status.Error(codes.Internal, err.Error()) } - if !found { - return nil, status.Error(codes.NotFound, "service not found") - } - // Get the service params serviceParams, err := k.GetServiceParams(ctx, request.ServiceId) if err != nil { diff --git a/x/services/keeper/keeper.go b/x/services/keeper/keeper.go index 8aa7d594..2028a9db 100644 --- a/x/services/keeper/keeper.go +++ b/x/services/keeper/keeper.go @@ -19,8 +19,7 @@ type Keeper struct { accountKeeper types.AccountKeeper poolKeeper types.CommunityPoolKeeper - storeService corestoretypes.KVStoreService - Schema collections.Schema + Schema collections.Schema nextServiceID collections.Sequence // Next service ID services collections.Map[uint32, types.Service] // service ID -> service @@ -48,7 +47,6 @@ func NewKeeper( accountKeeper: accountKeeper, poolKeeper: poolKeeper, authority: authority, - storeService: storeService, nextServiceID: collections.NewSequence( sb, diff --git a/x/services/keeper/msg_server.go b/x/services/keeper/msg_server.go index cc7d2a8f..def54956 100644 --- a/x/services/keeper/msg_server.go +++ b/x/services/keeper/msg_server.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "cosmossdk.io/collections" "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" @@ -104,15 +105,14 @@ func (k msgServer) CreateService(goCtx context.Context, msg *types.MsgCreateServ // UpdateService defines the rpc method for Msg/UpdateService func (k msgServer) UpdateService(ctx context.Context, msg *types.MsgUpdateService) (*types.MsgUpdateServiceResponse, error) { // Check if the service exists - service, found, err := k.GetService(ctx, msg.ServiceID) + service, err := k.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrServiceNotFound + } return nil, err } - if !found { - return nil, errors.Wrapf(sdkerrors.ErrInvalidRequest, "service with id %d not found", msg.ServiceID) - } - // Make sure the user that is updating the service is the admin if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can update the service") @@ -146,15 +146,14 @@ func (k msgServer) UpdateService(ctx context.Context, msg *types.MsgUpdateServic func (k msgServer) ActivateService(ctx context.Context, msg *types.MsgActivateService) (*types.MsgActivateServiceResponse, error) { // Check if the service exists - service, found, err := k.GetService(ctx, msg.ServiceID) + service, err := k.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrServiceNotFound + } return nil, err } - if !found { - return nil, errors.Wrapf(types.ErrServiceNotFound, "service with id %d not found", msg.ServiceID) - } - // Make sure the user that is activating the service is the admin if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can activate the service") @@ -181,15 +180,14 @@ func (k msgServer) ActivateService(ctx context.Context, msg *types.MsgActivateSe // DeactivateService defines the rpc method for Msg/DeactivateService func (k msgServer) DeactivateService(ctx context.Context, msg *types.MsgDeactivateService) (*types.MsgDeactivateServiceResponse, error) { // Check if the service exists - service, found, err := k.GetService(ctx, msg.ServiceID) + service, err := k.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrServiceNotFound + } return nil, err } - if !found { - return nil, errors.Wrapf(types.ErrServiceNotFound, "service with id %d not found", msg.ServiceID) - } - // Make sure the user that is deactivating the service is the admin if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can deactivate the service") @@ -215,15 +213,14 @@ func (k msgServer) DeactivateService(ctx context.Context, msg *types.MsgDeactiva func (k msgServer) DeleteService(ctx context.Context, msg *types.MsgDeleteService) (*types.MsgDeleteServiceResponse, error) { // Check if the service exists - service, found, err := k.GetService(ctx, msg.ServiceID) + service, err := k.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrServiceNotFound + } return nil, err } - if !found { - return nil, errors.Wrapf(types.ErrServiceNotFound, "service with id %d not found", msg.ServiceID) - } - // Make sure the user that is deleting the service is the admin if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can delete the service") @@ -250,15 +247,14 @@ func (k msgServer) DeleteService(ctx context.Context, msg *types.MsgDeleteServic // TransferServiceOwnership defines the rpc method for Msg/TransferServiceOwnership func (k msgServer) TransferServiceOwnership(ctx context.Context, msg *types.MsgTransferServiceOwnership) (*types.MsgTransferServiceOwnershipResponse, error) { // Check if the service exists - service, found, err := k.GetService(ctx, msg.ServiceID) + service, err := k.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrServiceNotFound + } return nil, err } - if !found { - return nil, types.ErrServiceNotFound - } - // Make sure only the admin can transfer the service ownership if service.Admin != msg.Sender { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "only the admin can transfer the service ownership") @@ -286,15 +282,14 @@ func (k msgServer) TransferServiceOwnership(ctx context.Context, msg *types.MsgT // SetServiceParams define the rpc method for Msg/SetServiceParams func (k msgServer) SetServiceParams(ctx context.Context, msg *types.MsgSetServiceParams) (*types.MsgSetServiceParamsResponse, error) { // Get the service whose params are being set - service, found, err := k.GetService(ctx, msg.ServiceID) + service, err := k.GetService(ctx, msg.ServiceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return nil, types.ErrServiceNotFound + } return nil, err } - if !found { - return nil, types.ErrServiceNotFound - } - // Ensure the sender is the service admin if msg.Sender != service.Admin { return nil, errors.Wrapf(sdkerrors.ErrUnauthorized, "sender must be the service admin") diff --git a/x/services/keeper/msg_server_test.go b/x/services/keeper/msg_server_test.go index a2148edb..fa27d524 100644 --- a/x/services/keeper/msg_server_test.go +++ b/x/services/keeper/msg_server_test.go @@ -1,6 +1,7 @@ package keeper_test import ( + "cosmossdk.io/collections" sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" @@ -118,9 +119,8 @@ func (suite *KeeperTestSuite) TestMsgServer_CreateService() { }, check: func(ctx sdk.Context) { // Make sure the service has been stored - stored, found, err := suite.k.GetService(ctx, 1) + stored, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_CREATED, @@ -200,9 +200,8 @@ func (suite *KeeperTestSuite) TestMsgServer_CreateService() { }, check: func(ctx sdk.Context) { // Make sure the service has been stored - stored, found, err := suite.k.GetService(ctx, 1) + stored, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_CREATED, @@ -386,9 +385,8 @@ func (suite *KeeperTestSuite) TestMsgServer_UpdateService() { }, check: func(ctx sdk.Context) { // Make sure the service was updated - stored, found, err := suite.k.GetService(ctx, 1) + stored, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_CREATED, @@ -632,9 +630,8 @@ func (suite *KeeperTestSuite) TestMsgServer_DeactivateService() { }, check: func(ctx sdk.Context) { // Make sure the service was deactivated - stored, found, err := suite.k.GetService(ctx, 1) + stored, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_INACTIVE, @@ -770,9 +767,8 @@ func (suite *KeeperTestSuite) TestMsgServer_DeleteService() { }, check: func(ctx sdk.Context) { // Make sure the service was removed - _, found, err := suite.k.GetService(ctx, 1) - suite.Require().NoError(err) - suite.Require().False(found) + _, err := suite.k.GetService(ctx, 1) + suite.Require().ErrorIs(err, collections.ErrNotFound) }, }, { @@ -803,9 +799,8 @@ func (suite *KeeperTestSuite) TestMsgServer_DeleteService() { }, check: func(ctx sdk.Context) { // Make sure the service was removed - _, found, err := suite.k.GetService(ctx, 1) - suite.Require().NoError(err) - suite.Require().False(found) + _, err := suite.k.GetService(ctx, 1) + suite.Require().ErrorIs(err, collections.ErrNotFound) }, }, } @@ -914,9 +909,8 @@ func (suite *KeeperTestSuite) TestMsgServer_TransferServiceOwnership() { }, check: func(ctx sdk.Context) { // Make sure the service was updated - stored, found, err := suite.k.GetService(ctx, 1) + stored, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_ACTIVE, @@ -1228,9 +1222,8 @@ func (suite *KeeperTestSuite) TestMsgServer_AccreditService() { ), }, check: func(ctx sdk.Context) { - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().True(service.Accredited) }, }, @@ -1325,9 +1318,8 @@ func (suite *KeeperTestSuite) TestMsgService_RevokeServiceAccreditation() { ), }, check: func(ctx sdk.Context) { - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().False(service.Accredited) }, }, diff --git a/x/services/keeper/services.go b/x/services/keeper/services.go index 15359333..fd8d8cb8 100644 --- a/x/services/keeper/services.go +++ b/x/services/keeper/services.go @@ -64,15 +64,14 @@ func (k *Keeper) CreateService(ctx context.Context, service types.Service) error // ActivateService activates the service with the given ID func (k *Keeper) ActivateService(ctx context.Context, serviceID uint32) error { - service, found, err := k.GetService(ctx, serviceID) + service, err := k.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.ErrServiceNotFound + } return err } - if !found { - return types.ErrServiceNotFound - } - // Check if the service is already active if service.Status == types.SERVICE_STATUS_ACTIVE { return types.ErrServiceAlreadyActive @@ -93,15 +92,14 @@ func (k *Keeper) ActivateService(ctx context.Context, serviceID uint32) error { // DeactivateService deactivates the service with the given ID func (k *Keeper) DeactivateService(ctx context.Context, serviceID uint32) error { - service, exists, err := k.GetService(ctx, serviceID) + service, err := k.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.ErrServiceNotFound + } return err } - if !exists { - return types.ErrServiceNotFound - } - // Make sure the service is active if service.Status != types.SERVICE_STATUS_ACTIVE { return types.ErrServiceNotActive @@ -121,15 +119,14 @@ func (k *Keeper) DeactivateService(ctx context.Context, serviceID uint32) error // DeleteService deletes the service with the given ID func (k *Keeper) DeleteService(ctx context.Context, serviceID uint32) error { - service, exists, err := k.GetService(ctx, serviceID) + service, err := k.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.ErrServiceNotFound + } return err } - if !exists { - return types.ErrServiceNotFound - } - // Make sure the service is not active if service.Status == types.SERVICE_STATUS_ACTIVE { return types.ErrServiceIsActive @@ -153,15 +150,14 @@ func (k *Keeper) DeleteService(ctx context.Context, serviceID uint32) error { // SetServiceAccredited sets the accreditation of the service with the given ID func (k *Keeper) SetServiceAccredited(ctx context.Context, serviceID uint32, accredited bool) error { // Check if the service exists - service, found, err := k.GetService(ctx, serviceID) + service, err := k.GetService(ctx, serviceID) if err != nil { + if errors.IsOf(err, collections.ErrNotFound) { + return types.ErrServiceNotFound + } return err } - if !found { - return errors.Wrapf(types.ErrServiceNotFound, "service with id %d not found", serviceID) - } - // Skip any operation if the service accreditation status does not change if service.Accredited == accredited { return nil @@ -184,15 +180,8 @@ func (k *Keeper) HasService(ctx context.Context, serviceID uint32) (bool, error) } // GetService returns an Service from the KVStore -func (k *Keeper) GetService(ctx context.Context, serviceID uint32) (service types.Service, found bool, err error) { - service, err = k.services.Get(ctx, serviceID) - if err != nil { - if errors.IsOf(err, collections.ErrNotFound) { - return service, false, nil - } - return service, false, err - } - return service, true, nil +func (k *Keeper) GetService(ctx context.Context, serviceID uint32) (service types.Service, err error) { + return k.services.Get(ctx, serviceID) } // GetServiceParams returns the params for the service with the given ID diff --git a/x/services/keeper/services_test.go b/x/services/keeper/services_test.go index 1bd95d58..a961c8b4 100644 --- a/x/services/keeper/services_test.go +++ b/x/services/keeper/services_test.go @@ -1,6 +1,7 @@ package keeper_test import ( + "cosmossdk.io/collections" sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" @@ -145,9 +146,8 @@ func (suite *KeeperTestSuite) TestKeeper_CreateService() { suite.Require().True(hasAccount) // Make sure the service has been created - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_ACTIVE, @@ -240,9 +240,8 @@ func (suite *KeeperTestSuite) TestKeeper_ActivateService() { serviceID: 1, shouldErr: false, check: func(ctx sdk.Context) { - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_ACTIVE, @@ -335,9 +334,8 @@ func (suite *KeeperTestSuite) TestKeeper_DeactivateService() { serviceID: 1, shouldErr: false, check: func(ctx sdk.Context) { - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().Equal(types.NewService( 1, types.SERVICE_STATUS_INACTIVE, @@ -415,9 +413,8 @@ func (suite *KeeperTestSuite) TestKeeper_SetServiceAccreditation() { shouldErr: false, check: func(ctx sdk.Context) { // Accreditation didn't change - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().False(service.Accredited) // Make sure the hook wasn't called @@ -444,9 +441,8 @@ func (suite *KeeperTestSuite) TestKeeper_SetServiceAccreditation() { shouldErr: false, check: func(ctx sdk.Context) { // Accreditation changed - service, found, err := suite.k.GetService(ctx, 1) + service, err := suite.k.GetService(ctx, 1) suite.Require().NoError(err) - suite.Require().True(found) suite.Require().True(service.Accredited) // Make sure the hook was called @@ -485,14 +481,12 @@ func (suite *KeeperTestSuite) TestKeeper_GetService() { name string store func(ctx sdk.Context) serviceID uint32 - shouldErr bool expFound bool expService types.Service }{ { name: "service not found returns false", serviceID: 1, - shouldErr: false, expFound: false, }, { @@ -511,7 +505,6 @@ func (suite *KeeperTestSuite) TestKeeper_GetService() { suite.Require().NoError(err) }, serviceID: 1, - shouldErr: false, expFound: true, expService: types.NewService( 1, @@ -534,18 +527,12 @@ func (suite *KeeperTestSuite) TestKeeper_GetService() { tc.store(ctx) } - service, found, err := suite.k.GetService(ctx, tc.serviceID) - if tc.shouldErr { - suite.Require().Error(err) + service, err := suite.k.GetService(ctx, tc.serviceID) + if !tc.expFound { + suite.Require().ErrorIs(err, collections.ErrNotFound) } else { suite.Require().NoError(err) - - if !tc.expFound { - suite.Require().False(found) - } else { - suite.Require().True(found) - suite.Require().Equal(tc.expService, service) - } + suite.Require().Equal(tc.expService, service) } }) }