From 8b758655ab61fac7d43545e23488ab3b8c01ca2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Kapka?= Date: Tue, 1 Jun 2021 08:28:00 +0200 Subject: [PATCH] Return correct status codes from beacon endpoints (#8960) --- beacon-chain/rpc/beaconv1/BUILD.bazel | 4 +- beacon-chain/rpc/beaconv1/blocks.go | 41 +- beacon-chain/rpc/beaconv1/pool.go | 8 +- beacon-chain/rpc/beaconv1/server.go | 2 +- beacon-chain/rpc/beaconv1/state.go | 145 +--- beacon-chain/rpc/beaconv1/state_test.go | 653 ++---------------- beacon-chain/rpc/beaconv1/validator.go | 109 ++- beacon-chain/rpc/beaconv1/validator_test.go | 194 ++---- beacon-chain/rpc/debugv1/BUILD.bazel | 1 - beacon-chain/rpc/debugv1/debug.go | 20 +- beacon-chain/rpc/nodev1/node.go | 8 +- beacon-chain/rpc/nodev1/node_test.go | 6 +- beacon-chain/rpc/service.go | 2 +- beacon-chain/rpc/statefetcher/BUILD.bazel | 1 + beacon-chain/rpc/statefetcher/fetcher.go | 174 ++++- beacon-chain/rpc/statefetcher/fetcher_test.go | 189 ++++- .../rpc/testutil/mock_state_fetcher.go | 8 +- .../state/stateV0/getters_validator.go | 24 +- 18 files changed, 664 insertions(+), 925 deletions(-) diff --git a/beacon-chain/rpc/beaconv1/BUILD.bazel b/beacon-chain/rpc/beaconv1/BUILD.bazel index d00a03ffbebe..6bfdc63d2265 100644 --- a/beacon-chain/rpc/beaconv1/BUILD.bazel +++ b/beacon-chain/rpc/beaconv1/BUILD.bazel @@ -28,6 +28,7 @@ go_library( "//beacon-chain/p2p:go_default_library", "//beacon-chain/rpc/statefetcher:go_default_library", "//beacon-chain/state/interface:go_default_library", + "//beacon-chain/state/stateV0:go_default_library", "//beacon-chain/state/stategen:go_default_library", "//proto/migration:go_default_library", "//shared/bytesutil:go_default_library", @@ -71,8 +72,8 @@ go_test( "//beacon-chain/operations/voluntaryexits:go_default_library", "//beacon-chain/p2p/testing:go_default_library", "//beacon-chain/rpc/statefetcher:go_default_library", + "//beacon-chain/rpc/testutil:go_default_library", "//beacon-chain/state/interface:go_default_library", - "//beacon-chain/state/stategen:go_default_library", "//proto/beacon/p2p/v1:go_default_library", "//proto/migration:go_default_library", "//shared/bls:go_default_library", @@ -82,7 +83,6 @@ go_test( "//shared/testutil:go_default_library", "//shared/testutil/assert:go_default_library", "//shared/testutil/require:go_default_library", - "@com_github_ethereum_go_ethereum//common/hexutil:go_default_library", "@com_github_grpc_ecosystem_grpc_gateway_v2//runtime:go_default_library", "@com_github_prysmaticlabs_eth2_types//:go_default_library", "@com_github_prysmaticlabs_ethereumapis//eth/v1:go_default_library", diff --git a/beacon-chain/rpc/beaconv1/blocks.go b/beacon-chain/rpc/beaconv1/blocks.go index ccc1e217796f..6039324f79cc 100644 --- a/beacon-chain/rpc/beaconv1/blocks.go +++ b/beacon-chain/rpc/beaconv1/blocks.go @@ -20,12 +20,32 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) +// blockIdParseError represents an error scenario where a block ID could not be parsed. +type blockIdParseError struct { + message string +} + +// newBlockIdParseError creates a new error instance. +func newBlockIdParseError(reason error) blockIdParseError { + return blockIdParseError{ + message: fmt.Sprintf("could not parse block ID: %v", reason), + } +} + +// Error returns the underlying error message. +func (e *blockIdParseError) Error() string { + return e.message +} + // GetBlockHeader retrieves block header for given block id. func (bs *Server) GetBlockHeader(ctx context.Context, req *ethpb.BlockRequest) (*ethpb.BlockHeaderResponse, error) { ctx, span := trace.StartSpan(ctx, "beaconv1.GetBlockHeader") defer span.End() rBlk, err := bs.blockFromBlockID(ctx, req.BlockId) + if invalidBlockIdErr, ok := err.(*blockIdParseError); ok { + return nil, status.Errorf(codes.InvalidArgument, "Invalid block ID: %v", invalidBlockIdErr) + } if err != nil { return nil, status.Errorf(codes.Internal, "Could not get block from block ID: %v", err) } @@ -141,13 +161,13 @@ func (bs *Server) SubmitBlock(ctx context.Context, req *ethpb.BeaconBlockContain blk := req.Message rBlock, err := migration.V1ToV1Alpha1Block(ðpb.SignedBeaconBlock{Block: blk, Signature: req.Signature}) if err != nil { - return nil, status.Errorf(codes.Internal, "Could not convert block to v1") + return nil, status.Errorf(codes.InvalidArgument, "Could not convert block to v1 block") } v1alpha1Block := interfaces.WrappedPhase0SignedBeaconBlock(rBlock) root, err := blk.HashTreeRoot() if err != nil { - return nil, status.Errorf(codes.Internal, "Could not tree hash block: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "Could not tree hash block: %v", err) } // Do not block proposal critical path with debug logging or block feed updates. @@ -172,12 +192,15 @@ func (bs *Server) SubmitBlock(ctx context.Context, req *ethpb.BeaconBlockContain return &emptypb.Empty{}, nil } -// GetBlock retrieves block details for given block id. +// GetBlock retrieves block details for given block ID. func (bs *Server) GetBlock(ctx context.Context, req *ethpb.BlockRequest) (*ethpb.BlockResponse, error) { ctx, span := trace.StartSpan(ctx, "beaconv1.GetBlock") defer span.End() rBlk, err := bs.blockFromBlockID(ctx, req.BlockId) + if invalidBlockIdErr, ok := err.(*blockIdParseError); ok { + return nil, status.Errorf(codes.InvalidArgument, "Invalid block ID: %v", invalidBlockIdErr) + } if err != nil { return nil, status.Errorf(codes.Internal, "Could not get block from block ID: %v", err) } @@ -191,7 +214,7 @@ func (bs *Server) GetBlock(ctx context.Context, req *ethpb.BlockRequest) (*ethpb v1Block, err := migration.V1Alpha1ToV1Block(blk) if err != nil { - return nil, status.Errorf(codes.Internal, "Could not convert block to v1") + return nil, status.Errorf(codes.Internal, "Could not convert block to v1 block") } return ðpb.BlockResponse{ @@ -248,7 +271,7 @@ func (bs *Server) GetBlockRoot(ctx context.Context, req *ethpb.BlockRequest) (*e } else { slot, err := strconv.ParseUint(string(req.BlockId), 10, 64) if err != nil { - return nil, status.Errorf(codes.Internal, "could not decode block id: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "Could not parse block ID: %v", err) } hasRoots, roots, err := bs.BeaconDB.BlockRootsBySlot(ctx, types.Slot(slot)) if err != nil { @@ -288,6 +311,9 @@ func (bs *Server) ListBlockAttestations(ctx context.Context, req *ethpb.BlockReq defer span.End() rBlk, err := bs.blockFromBlockID(ctx, req.BlockId) + if invalidBlockIdErr, ok := err.(*blockIdParseError); ok { + return nil, status.Errorf(codes.InvalidArgument, "Invalid block ID: %v", invalidBlockIdErr) + } if err != nil { return nil, status.Errorf(codes.Internal, "Could not get block from block ID: %v", err) } @@ -302,7 +328,7 @@ func (bs *Server) ListBlockAttestations(ctx context.Context, req *ethpb.BlockReq v1Block, err := migration.V1Alpha1ToV1Block(blk) if err != nil { - return nil, status.Errorf(codes.Internal, "Could not convert block to v1") + return nil, status.Errorf(codes.Internal, "Could not convert block to v1 block") } return ðpb.BlockAttestationsResponse{ Data: v1Block.Block.Body.Attestations, @@ -339,7 +365,8 @@ func (bs *Server) blockFromBlockID(ctx context.Context, blockId []byte) (interfa } else { slot, err := strconv.ParseUint(string(blockId), 10, 64) if err != nil { - return nil, errors.Wrap(err, "could not decode block id") + e := newBlockIdParseError(err) + return nil, &e } _, blks, err := bs.BeaconDB.BlocksBySlot(ctx, types.Slot(slot)) if err != nil { diff --git a/beacon-chain/rpc/beaconv1/pool.go b/beacon-chain/rpc/beaconv1/pool.go index 91bb312e1833..97941defe0a4 100644 --- a/beacon-chain/rpc/beaconv1/pool.go +++ b/beacon-chain/rpc/beaconv1/pool.go @@ -118,7 +118,7 @@ func (bs *Server) SubmitAttestations(ctx context.Context, req *ethpb.SubmitAttes if err != nil { return nil, status.Errorf(codes.Internal, "Could not prepare attestation failure information: %v", err) } - return nil, status.Errorf(codes.Internal, "One or more attestations failed validation") + return nil, status.Errorf(codes.InvalidArgument, "One or more attestations failed validation") } return &emptypb.Empty{}, nil @@ -160,7 +160,7 @@ func (bs *Server) SubmitAttesterSlashing(ctx context.Context, req *ethpb.Atteste alphaSlashing := migration.V1AttSlashingToV1Alpha1(req) err = blocks.VerifyAttesterSlashing(ctx, headState, alphaSlashing) if err != nil { - return nil, status.Errorf(codes.Internal, "Invalid attester slashing: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "Invalid attester slashing: %v", err) } err = bs.SlashingsPool.InsertAttesterSlashing(ctx, headState, alphaSlashing) @@ -212,7 +212,7 @@ func (bs *Server) SubmitProposerSlashing(ctx context.Context, req *ethpb.Propose alphaSlashing := migration.V1ProposerSlashingToV1Alpha1(req) err = blocks.VerifyProposerSlashing(headState, alphaSlashing) if err != nil { - return nil, status.Errorf(codes.Internal, "Invalid proposer slashing: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "Invalid proposer slashing: %v", err) } err = bs.SlashingsPool.InsertProposerSlashing(ctx, headState, alphaSlashing) @@ -269,7 +269,7 @@ func (bs *Server) SubmitVoluntaryExit(ctx context.Context, req *ethpb.SignedVolu alphaExit := migration.V1ExitToV1Alpha1(req) err = blocks.VerifyExitAndSignature(validator, headState.Slot(), headState.Fork(), alphaExit, headState.GenesisValidatorRoot()) if err != nil { - return nil, status.Errorf(codes.Internal, "Invalid voluntary exit: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "Invalid voluntary exit: %v", err) } bs.VoluntaryExitsPool.InsertVoluntaryExit(ctx, headState, alphaExit) diff --git a/beacon-chain/rpc/beaconv1/server.go b/beacon-chain/rpc/beaconv1/server.go index fa1bdeb71d26..5f7877181494 100644 --- a/beacon-chain/rpc/beaconv1/server.go +++ b/beacon-chain/rpc/beaconv1/server.go @@ -29,5 +29,5 @@ type Server struct { SlashingsPool slashings.PoolManager VoluntaryExitsPool voluntaryexits.PoolManager StateGenService stategen.StateManager - StateFetcher statefetcher.StateProvider + StateFetcher statefetcher.Fetcher } diff --git a/beacon-chain/rpc/beaconv1/state.go b/beacon-chain/rpc/beaconv1/state.go index 2d4bce020b8a..de54d1aa3ee6 100644 --- a/beacon-chain/rpc/beaconv1/state.go +++ b/beacon-chain/rpc/beaconv1/state.go @@ -3,17 +3,11 @@ package beaconv1 import ( "bytes" "context" - "fmt" - "strconv" - "strings" - "github.com/pkg/errors" - types "github.com/prysmaticlabs/eth2-types" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1" eth "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" - "github.com/prysmaticlabs/prysm/beacon-chain/core/helpers" + "github.com/prysmaticlabs/prysm/beacon-chain/rpc/statefetcher" iface "github.com/prysmaticlabs/prysm/beacon-chain/state/interface" - "github.com/prysmaticlabs/prysm/shared/bytesutil" "github.com/prysmaticlabs/prysm/shared/params" "go.opencensus.io/trace" "google.golang.org/grpc/codes" @@ -59,8 +53,13 @@ func (bs *Server) GetStateRoot(ctx context.Context, req *ethpb.StateRequest) (*e err error ) - root, err = bs.stateRoot(ctx, req.StateId) + root, err = bs.StateFetcher.StateRoot(ctx, req.StateId) if err != nil { + if rootNotFoundErr, ok := err.(*statefetcher.StateRootNotFoundError); ok { + return nil, status.Errorf(codes.NotFound, "State root not found: %v", rootNotFoundErr) + } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { + return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) + } return nil, status.Errorf(codes.Internal, "Could not get state root: %v", err) } @@ -83,6 +82,11 @@ func (bs *Server) GetStateFork(ctx context.Context, req *ethpb.StateRequest) (*e state, err = bs.StateFetcher.State(ctx, req.StateId) if err != nil { + if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { + return nil, status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) + } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { + return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) + } return nil, status.Errorf(codes.Internal, "Could not get state: %v", err) } @@ -109,6 +113,11 @@ func (bs *Server) GetFinalityCheckpoints(ctx context.Context, req *ethpb.StateRe state, err = bs.StateFetcher.State(ctx, req.StateId) if err != nil { + if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { + return nil, status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) + } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { + return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) + } return nil, status.Errorf(codes.Internal, "Could not get state: %v", err) } @@ -121,126 +130,6 @@ func (bs *Server) GetFinalityCheckpoints(ctx context.Context, req *ethpb.StateRe }, nil } -func (bs *Server) stateRoot(ctx context.Context, stateId []byte) ([]byte, error) { - var ( - root []byte - err error - ) - - stateIdString := strings.ToLower(string(stateId)) - switch stateIdString { - case "head": - root, err = bs.headStateRoot(ctx) - case "genesis": - root, err = bs.genesisStateRoot(ctx) - case "finalized": - root, err = bs.finalizedStateRoot(ctx) - case "justified": - root, err = bs.justifiedStateRoot(ctx) - default: - if len(stateId) == 32 { - root, err = bs.stateRootByHex(ctx, stateId) - } else { - slotNumber, parseErr := strconv.ParseUint(stateIdString, 10, 64) - if parseErr != nil { - // ID format does not match any valid options. - return nil, errors.New("invalid state ID: " + stateIdString) - } - root, err = bs.stateRootBySlot(ctx, types.Slot(slotNumber)) - } - } - - return root, err -} - -func (bs *Server) headStateRoot(ctx context.Context) ([]byte, error) { - b, err := bs.ChainInfoFetcher.HeadBlock(ctx) - if err != nil { - return nil, errors.Wrap(err, "could not get head block") - } - if err := helpers.VerifyNilBeaconBlock(b); err != nil { - return nil, err - } - return b.Block().StateRoot(), nil -} - -func (bs *Server) genesisStateRoot(ctx context.Context) ([]byte, error) { - b, err := bs.BeaconDB.GenesisBlock(ctx) - if err != nil { - return nil, errors.Wrap(err, "could not get genesis block") - } - if err := helpers.VerifyNilBeaconBlock(b); err != nil { - return nil, err - } - return b.Block().StateRoot(), nil -} - -func (bs *Server) finalizedStateRoot(ctx context.Context) ([]byte, error) { - cp, err := bs.BeaconDB.FinalizedCheckpoint(ctx) - if err != nil { - return nil, errors.Wrap(err, "could not get finalized checkpoint") - } - b, err := bs.BeaconDB.Block(ctx, bytesutil.ToBytes32(cp.Root)) - if err != nil { - return nil, errors.Wrap(err, "could not get finalized block") - } - if err := helpers.VerifyNilBeaconBlock(b); err != nil { - return nil, err - } - return b.Block().StateRoot(), nil -} - -func (bs *Server) justifiedStateRoot(ctx context.Context) ([]byte, error) { - cp, err := bs.BeaconDB.JustifiedCheckpoint(ctx) - if err != nil { - return nil, errors.Wrap(err, "could not get justified checkpoint") - } - b, err := bs.BeaconDB.Block(ctx, bytesutil.ToBytes32(cp.Root)) - if err != nil { - return nil, errors.Wrap(err, "could not get justified block") - } - if err := helpers.VerifyNilBeaconBlock(b); err != nil { - return nil, err - } - return b.Block().StateRoot(), nil -} - -func (bs *Server) stateRootByHex(ctx context.Context, stateId []byte) ([]byte, error) { - var stateRoot [32]byte - copy(stateRoot[:], stateId) - headState, err := bs.ChainInfoFetcher.HeadState(ctx) - if err != nil { - return nil, errors.Wrap(err, "could not get head state") - } - for _, root := range headState.StateRoots() { - if bytes.Equal(root, stateRoot[:]) { - return stateRoot[:], nil - } - } - return nil, fmt.Errorf("state not found in the last %d state roots", len(headState.StateRoots())) -} - -func (bs *Server) stateRootBySlot(ctx context.Context, slot types.Slot) ([]byte, error) { - currentSlot := bs.GenesisTimeFetcher.CurrentSlot() - if slot > currentSlot { - return nil, errors.New("slot cannot be in the future") - } - found, blks, err := bs.BeaconDB.BlocksBySlot(ctx, slot) - if err != nil { - return nil, errors.Wrap(err, "could not get blocks") - } - if !found { - return nil, errors.New("no block exists") - } - if len(blks) != 1 { - return nil, errors.New("multiple blocks exist in same slot") - } - if blks[0] == nil || blks[0].IsNil() || blks[0].Block().IsNil() { - return nil, errors.New("nil block") - } - return blks[0].Block().StateRoot(), nil -} - func checkpoint(sourceCheckpoint *eth.Checkpoint) *ethpb.Checkpoint { if sourceCheckpoint != nil { return ðpb.Checkpoint{ diff --git a/beacon-chain/rpc/beaconv1/state_test.go b/beacon-chain/rpc/beaconv1/state_test.go index f1f2da115153..d373744fa232 100644 --- a/beacon-chain/rpc/beaconv1/state_test.go +++ b/beacon-chain/rpc/beaconv1/state_test.go @@ -2,25 +2,17 @@ package beaconv1 import ( "context" - "fmt" - "strconv" - "strings" "testing" "time" - "github.com/ethereum/go-ethereum/common/hexutil" - types "github.com/prysmaticlabs/eth2-types" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1" eth "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" chainMock "github.com/prysmaticlabs/prysm/beacon-chain/blockchain/testing" - testDB "github.com/prysmaticlabs/prysm/beacon-chain/db/testing" - "github.com/prysmaticlabs/prysm/beacon-chain/rpc/statefetcher" - "github.com/prysmaticlabs/prysm/beacon-chain/state/stategen" + "github.com/prysmaticlabs/prysm/beacon-chain/rpc/testutil" pb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" "github.com/prysmaticlabs/prysm/shared/bytesutil" - "github.com/prysmaticlabs/prysm/shared/interfaces" "github.com/prysmaticlabs/prysm/shared/params" - "github.com/prysmaticlabs/prysm/shared/testutil" + sharedtestutil "github.com/prysmaticlabs/prysm/shared/testutil" "github.com/prysmaticlabs/prysm/shared/testutil/assert" "github.com/prysmaticlabs/prysm/shared/testutil/require" "google.golang.org/protobuf/types/known/emptypb" @@ -80,188 +72,26 @@ func TestGetGenesis(t *testing.T) { } func TestGetStateRoot(t *testing.T) { - db := testDB.SetupDB(t) ctx := context.Background() + fakeState, err := sharedtestutil.NewBeaconState() + require.NoError(t, err) + stateRoot, err := fakeState.HashTreeRoot(ctx) + require.NoError(t, err) + server := &Server{ + StateFetcher: &testutil.MockFetcher{ + BeaconStateRoot: stateRoot[:], + }, + } - t.Run("Head", func(t *testing.T) { - b := testutil.NewBeaconBlock() - b.Block.StateRoot = bytesutil.PadTo([]byte("head"), 32) - s := Server{ - ChainInfoFetcher: &chainMock.ChainService{Block: interfaces.WrappedPhase0SignedBeaconBlock(b)}, - } - - resp, err := s.GetStateRoot(ctx, ðpb.StateRequest{ - StateId: []byte("head"), - }) - require.NoError(t, err) - assert.DeepEqual(t, bytesutil.PadTo([]byte("head"), 32), resp.Data.Root) - }) - - t.Run("Genesis", func(t *testing.T) { - b := testutil.NewBeaconBlock() - b.Block.StateRoot = bytesutil.PadTo([]byte("genesis"), 32) - require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(b))) - r, err := b.Block.HashTreeRoot() - require.NoError(t, err) - require.NoError(t, db.SaveStateSummary(ctx, &pb.StateSummary{Root: r[:]})) - require.NoError(t, db.SaveGenesisBlockRoot(ctx, r)) - s := Server{ - BeaconDB: db, - } - - resp, err := s.GetStateRoot(ctx, ðpb.StateRequest{ - StateId: []byte("genesis"), - }) - require.NoError(t, err) - assert.DeepEqual(t, bytesutil.PadTo([]byte("genesis"), 32), resp.Data.Root) - }) - - t.Run("Finalized", func(t *testing.T) { - parent := testutil.NewBeaconBlock() - parentR, err := parent.Block.HashTreeRoot() - require.NoError(t, err) - require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(parent))) - require.NoError(t, db.SaveGenesisBlockRoot(ctx, parentR)) - b := testutil.NewBeaconBlock() - b.Block.ParentRoot = parentR[:] - b.Block.StateRoot = bytesutil.PadTo([]byte("finalized"), 32) - require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(b))) - r, err := b.Block.HashTreeRoot() - require.NoError(t, err) - require.NoError(t, db.SaveStateSummary(ctx, &pb.StateSummary{Root: r[:]})) - require.NoError(t, db.SaveFinalizedCheckpoint(ctx, ð.Checkpoint{Root: r[:]})) - s := Server{ - BeaconDB: db, - } - - resp, err := s.GetStateRoot(ctx, ðpb.StateRequest{ - StateId: []byte("finalized"), - }) - require.NoError(t, err) - assert.DeepEqual(t, bytesutil.PadTo([]byte("finalized"), 32), resp.Data.Root) - }) - - t.Run("Justified", func(t *testing.T) { - parent := testutil.NewBeaconBlock() - parentR, err := parent.Block.HashTreeRoot() - require.NoError(t, err) - require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(parent))) - require.NoError(t, db.SaveGenesisBlockRoot(ctx, parentR)) - b := testutil.NewBeaconBlock() - b.Block.ParentRoot = parentR[:] - b.Block.StateRoot = bytesutil.PadTo([]byte("justified"), 32) - require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(b))) - r, err := b.Block.HashTreeRoot() - require.NoError(t, err) - require.NoError(t, db.SaveStateSummary(ctx, &pb.StateSummary{Root: r[:]})) - require.NoError(t, db.SaveJustifiedCheckpoint(ctx, ð.Checkpoint{Root: r[:]})) - s := Server{ - BeaconDB: db, - } - - resp, err := s.GetStateRoot(ctx, ðpb.StateRequest{ - StateId: []byte("justified"), - }) - require.NoError(t, err) - assert.DeepEqual(t, bytesutil.PadTo([]byte("justified"), 32), resp.Data.Root) - }) - - t.Run("Hex root", func(t *testing.T) { - state, err := testutil.NewBeaconState(testutil.FillRootsNaturalOpt) - require.NoError(t, err) - chainService := &chainMock.ChainService{ - State: state, - } - s := Server{ - ChainInfoFetcher: chainService, - } - stateId, err := hexutil.Decode("0x" + strings.Repeat("0", 63) + "1") - require.NoError(t, err) - resp, err := s.GetStateRoot(ctx, ðpb.StateRequest{ - StateId: stateId, - }) - require.NoError(t, err) - assert.DeepEqual(t, stateId, resp.Data.Root) - }) - - t.Run("Hex root not found", func(t *testing.T) { - state, err := testutil.NewBeaconState() - require.NoError(t, err) - chainService := &chainMock.ChainService{ - State: state, - } - s := Server{ - ChainInfoFetcher: chainService, - } - stateId, err := hexutil.Decode("0x" + strings.Repeat("f", 64)) - require.NoError(t, err) - _, err = s.GetStateRoot(ctx, ðpb.StateRequest{ - StateId: stateId, - }) - assert.ErrorContains(t, fmt.Sprintf("state not found in the last %d state roots", len(state.StateRoots())), err) - }) - - t.Run("Slot", func(t *testing.T) { - b := testutil.NewBeaconBlock() - b.Block.Slot = 100 - b.Block.StateRoot = bytesutil.PadTo([]byte("slot"), 32) - require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(b))) - s := Server{ - BeaconDB: db, - GenesisTimeFetcher: &chainMock.ChainService{}, - } - - resp, err := s.GetStateRoot(ctx, ðpb.StateRequest{ - StateId: []byte("100"), - }) - require.NoError(t, err) - assert.DeepEqual(t, bytesutil.PadTo([]byte("slot"), 32), resp.Data.Root) - }) - - t.Run("Multiple slots", func(t *testing.T) { - b := testutil.NewBeaconBlock() - b.Block.Slot = 100 - b.Block.StateRoot = bytesutil.PadTo([]byte("slot"), 32) - require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(b))) - b = testutil.NewBeaconBlock() - b.Block.Slot = 100 - b.Block.StateRoot = bytesutil.PadTo([]byte("sLot"), 32) - require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(b))) - s := Server{ - BeaconDB: db, - GenesisTimeFetcher: &chainMock.ChainService{}, - } - - _, err := s.GetStateRoot(ctx, ðpb.StateRequest{ - StateId: []byte("100"), - }) - assert.ErrorContains(t, "multiple blocks exist in same slot", err) - }) - - t.Run("Slot too big", func(t *testing.T) { - s := Server{ - GenesisTimeFetcher: &chainMock.ChainService{ - Genesis: time.Now(), - }, - } - _, err := s.GetStateRoot(ctx, ðpb.StateRequest{ - StateId: []byte(strconv.FormatUint(1, 10)), - }) - assert.ErrorContains(t, "slot cannot be in the future", err) - }) - - t.Run("Invalid state", func(t *testing.T) { - s := Server{} - _, err := s.GetStateRoot(ctx, ðpb.StateRequest{ - StateId: []byte("foo"), - }) - require.ErrorContains(t, "invalid state ID: foo", err) + resp, err := server.GetStateRoot(context.Background(), ðpb.StateRequest{ + StateId: make([]byte, 0), }) + require.NoError(t, err) + assert.NotNil(t, resp) + assert.DeepEqual(t, stateRoot[:], resp.Data.Root) } func TestGetStateFork(t *testing.T) { - ctx := context.Background() - fillFork := func(state *pb.BeaconState) error { state.Fork = &pb.Fork{ PreviousVersion: []byte("prev"), @@ -270,197 +100,26 @@ func TestGetStateFork(t *testing.T) { } return nil } - headSlot := types.Slot(123) - fillSlot := func(state *pb.BeaconState) error { - state.Slot = headSlot - return nil - } - state, err := testutil.NewBeaconState(testutil.FillRootsNaturalOpt, fillFork, fillSlot) - require.NoError(t, err) - stateRoot, err := state.HashTreeRoot(ctx) + fakeState, err := sharedtestutil.NewBeaconState(fillFork) require.NoError(t, err) + server := &Server{ + StateFetcher: &testutil.MockFetcher{ + BeaconState: fakeState, + }, + } - t.Run("Head", func(t *testing.T) { - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, - }, - } - - resp, err := s.GetStateFork(ctx, ðpb.StateRequest{ - StateId: []byte("head"), - }) - require.NoError(t, err) - assert.DeepEqual(t, []byte("prev"), resp.Data.PreviousVersion) - assert.DeepEqual(t, []byte("curr"), resp.Data.CurrentVersion) - assert.Equal(t, types.Epoch(123), resp.Data.Epoch) - }) - - t.Run("Genesis", func(t *testing.T) { - db := testDB.SetupDB(t) - b := testutil.NewBeaconBlock() - b.Block.StateRoot = bytesutil.PadTo([]byte("genesis"), 32) - require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(b))) - r, err := b.Block.HashTreeRoot() - require.NoError(t, err) - require.NoError(t, db.SaveStateSummary(ctx, &pb.StateSummary{Root: r[:]})) - require.NoError(t, db.SaveGenesisBlockRoot(ctx, r)) - st, err := testutil.NewBeaconState(func(state *pb.BeaconState) error { - state.Fork = &pb.Fork{ - PreviousVersion: []byte("prev"), - CurrentVersion: []byte("curr"), - Epoch: 123, - } - return nil - }) - require.NoError(t, err) - require.NoError(t, db.SaveState(ctx, st, r)) - - s := Server{ - StateFetcher: statefetcher.StateProvider{ - BeaconDB: db, - }, - } - - resp, err := s.GetStateFork(ctx, ðpb.StateRequest{ - StateId: []byte("genesis"), - }) - require.NoError(t, err) - assert.DeepEqual(t, []byte("prev"), resp.Data.PreviousVersion) - assert.DeepEqual(t, []byte("curr"), resp.Data.CurrentVersion) - assert.Equal(t, types.Epoch(123), resp.Data.Epoch) - }) - - t.Run("Finalized", func(t *testing.T) { - stateGen := stategen.NewMockService() - stateGen.StatesByRoot[stateRoot] = state - - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{ - FinalizedCheckPoint: ð.Checkpoint{ - Root: stateRoot[:], - }, - }, - StateGenService: stateGen, - }, - } - - resp, err := s.GetStateFork(ctx, ðpb.StateRequest{ - StateId: []byte("finalized"), - }) - require.NoError(t, err) - assert.DeepEqual(t, []byte("prev"), resp.Data.PreviousVersion) - assert.DeepEqual(t, []byte("curr"), resp.Data.CurrentVersion) - assert.Equal(t, types.Epoch(123), resp.Data.Epoch) - }) - - t.Run("Justified", func(t *testing.T) { - stateGen := stategen.NewMockService() - stateGen.StatesByRoot[stateRoot] = state - - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{ - CurrentJustifiedCheckPoint: ð.Checkpoint{ - Root: stateRoot[:], - }, - }, - StateGenService: stateGen, - }, - } - - resp, err := s.GetStateFork(ctx, ðpb.StateRequest{ - StateId: []byte("justified"), - }) - require.NoError(t, err) - assert.DeepEqual(t, []byte("prev"), resp.Data.PreviousVersion) - assert.DeepEqual(t, []byte("curr"), resp.Data.CurrentVersion) - assert.Equal(t, types.Epoch(123), resp.Data.Epoch) - }) - - t.Run("Hex root", func(t *testing.T) { - stateId, err := hexutil.Decode("0x" + strings.Repeat("0", 63) + "1") - require.NoError(t, err) - stateGen := stategen.NewMockService() - stateGen.StatesByRoot[bytesutil.ToBytes32(stateId)] = state - - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, - StateGenService: stateGen, - }, - } - - resp, err := s.GetStateFork(ctx, ðpb.StateRequest{ - StateId: stateId, - }) - require.NoError(t, err) - assert.DeepEqual(t, []byte("prev"), resp.Data.PreviousVersion) - assert.DeepEqual(t, []byte("curr"), resp.Data.CurrentVersion) - assert.Equal(t, types.Epoch(123), resp.Data.Epoch) - }) - - t.Run("Hex root not found", func(t *testing.T) { - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, - }, - } - stateId, err := hexutil.Decode("0x" + strings.Repeat("f", 64)) - require.NoError(t, err) - _, err = s.GetStateFork(ctx, ðpb.StateRequest{ - StateId: stateId, - }) - require.ErrorContains(t, "state not found in the last 8192 state roots", err) - }) - - t.Run("Slot", func(t *testing.T) { - stateGen := stategen.NewMockService() - stateGen.StatesBySlot[headSlot] = state - - s := Server{ - StateFetcher: statefetcher.StateProvider{ - GenesisTimeFetcher: &chainMock.ChainService{Slot: &headSlot}, - StateGenService: stateGen, - }, - } - - resp, err := s.GetStateFork(ctx, ðpb.StateRequest{ - StateId: []byte(strconv.FormatUint(uint64(headSlot), 10)), - }) - require.NoError(t, err) - assert.DeepEqual(t, []byte("prev"), resp.Data.PreviousVersion) - assert.DeepEqual(t, []byte("curr"), resp.Data.CurrentVersion) - assert.Equal(t, types.Epoch(123), resp.Data.Epoch) - }) - - t.Run("Slot too big", func(t *testing.T) { - s := Server{ - StateFetcher: statefetcher.StateProvider{ - GenesisTimeFetcher: &chainMock.ChainService{ - Genesis: time.Now(), - }, - }, - } - _, err := s.GetStateFork(ctx, ðpb.StateRequest{ - StateId: []byte(strconv.FormatUint(1, 10)), - }) - assert.ErrorContains(t, "slot cannot be in the future", err) - }) - - t.Run("Invalid state", func(t *testing.T) { - s := Server{} - _, err := s.GetStateFork(ctx, ðpb.StateRequest{ - StateId: []byte("foo"), - }) - require.ErrorContains(t, "invalid state ID", err) + resp, err := server.GetStateFork(context.Background(), ðpb.StateRequest{ + StateId: make([]byte, 0), }) + require.NoError(t, err) + assert.NotNil(t, resp) + expectedFork := fakeState.Fork() + assert.Equal(t, expectedFork.Epoch, resp.Data.Epoch) + assert.DeepEqual(t, expectedFork.CurrentVersion, resp.Data.CurrentVersion) + assert.DeepEqual(t, expectedFork.PreviousVersion, resp.Data.PreviousVersion) } func TestGetFinalityCheckpoints(t *testing.T) { - ctx := context.Background() - fillCheckpoints := func(state *pb.BeaconState) error { state.PreviousJustifiedCheckpoint = ð.Checkpoint{ Root: bytesutil.PadTo([]byte("previous"), 32), @@ -476,243 +135,23 @@ func TestGetFinalityCheckpoints(t *testing.T) { } return nil } - headSlot := types.Slot(123) - fillSlot := func(state *pb.BeaconState) error { - state.Slot = headSlot - return nil - } - state, err := testutil.NewBeaconState(testutil.FillRootsNaturalOpt, fillCheckpoints, fillSlot) + fakeState, err := sharedtestutil.NewBeaconState(fillCheckpoints) require.NoError(t, err) - stateRoot, err := state.HashTreeRoot(ctx) - require.NoError(t, err) - - t.Run("Head", func(t *testing.T) { - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, - }, - } - - resp, err := s.GetFinalityCheckpoints(ctx, ðpb.StateRequest{ - StateId: []byte("head"), - }) - require.NoError(t, err) - assert.DeepEqual(t, bytesutil.PadTo([]byte("previous"), 32), resp.Data.PreviousJustified.Root) - assert.Equal(t, types.Epoch(113), resp.Data.PreviousJustified.Epoch) - assert.DeepEqual(t, bytesutil.PadTo([]byte("current"), 32), resp.Data.CurrentJustified.Root) - assert.Equal(t, types.Epoch(123), resp.Data.CurrentJustified.Epoch) - assert.DeepEqual(t, bytesutil.PadTo([]byte("finalized"), 32), resp.Data.Finalized.Root) - assert.Equal(t, types.Epoch(103), resp.Data.Finalized.Epoch) - }) - - t.Run("Genesis", func(t *testing.T) { - db := testDB.SetupDB(t) - b := testutil.NewBeaconBlock() - b.Block.StateRoot = bytesutil.PadTo([]byte("genesis"), 32) - require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(b))) - r, err := b.Block.HashTreeRoot() - require.NoError(t, err) - require.NoError(t, db.SaveStateSummary(ctx, &pb.StateSummary{Root: r[:]})) - require.NoError(t, db.SaveGenesisBlockRoot(ctx, r)) - st, err := testutil.NewBeaconState(func(state *pb.BeaconState) error { - state.PreviousJustifiedCheckpoint = ð.Checkpoint{ - Root: bytesutil.PadTo([]byte("previous"), 32), - Epoch: 113, - } - state.CurrentJustifiedCheckpoint = ð.Checkpoint{ - Root: bytesutil.PadTo([]byte("current"), 32), - Epoch: 123, - } - state.FinalizedCheckpoint = ð.Checkpoint{ - Root: bytesutil.PadTo([]byte("finalized"), 32), - Epoch: 103, - } - return nil - }) - require.NoError(t, err) - require.NoError(t, db.SaveState(ctx, st, r)) - - s := Server{ - StateFetcher: statefetcher.StateProvider{ - BeaconDB: db, - }, - } - - resp, err := s.GetFinalityCheckpoints(ctx, ðpb.StateRequest{ - StateId: []byte("genesis"), - }) - require.NoError(t, err) - assert.DeepEqual(t, bytesutil.PadTo([]byte("previous"), 32), resp.Data.PreviousJustified.Root) - assert.Equal(t, types.Epoch(113), resp.Data.PreviousJustified.Epoch) - assert.DeepEqual(t, bytesutil.PadTo([]byte("current"), 32), resp.Data.CurrentJustified.Root) - assert.Equal(t, types.Epoch(123), resp.Data.CurrentJustified.Epoch) - assert.DeepEqual(t, bytesutil.PadTo([]byte("finalized"), 32), resp.Data.Finalized.Root) - assert.Equal(t, types.Epoch(103), resp.Data.Finalized.Epoch) - }) - - t.Run("Finalized", func(t *testing.T) { - stateGen := stategen.NewMockService() - stateGen.StatesByRoot[stateRoot] = state - - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{ - FinalizedCheckPoint: ð.Checkpoint{ - Root: stateRoot[:], - }, - }, - StateGenService: stateGen, - }, - } - - resp, err := s.GetFinalityCheckpoints(ctx, ðpb.StateRequest{ - StateId: []byte("finalized"), - }) - require.NoError(t, err) - assert.DeepEqual(t, bytesutil.PadTo([]byte("previous"), 32), resp.Data.PreviousJustified.Root) - assert.Equal(t, types.Epoch(113), resp.Data.PreviousJustified.Epoch) - assert.DeepEqual(t, bytesutil.PadTo([]byte("current"), 32), resp.Data.CurrentJustified.Root) - assert.Equal(t, types.Epoch(123), resp.Data.CurrentJustified.Epoch) - assert.DeepEqual(t, bytesutil.PadTo([]byte("finalized"), 32), resp.Data.Finalized.Root) - assert.Equal(t, types.Epoch(103), resp.Data.Finalized.Epoch) - }) - - t.Run("Justified", func(t *testing.T) { - stateGen := stategen.NewMockService() - stateGen.StatesByRoot[stateRoot] = state - - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{ - CurrentJustifiedCheckPoint: ð.Checkpoint{ - Root: stateRoot[:], - }, - }, - StateGenService: stateGen, - }, - } - - resp, err := s.GetFinalityCheckpoints(ctx, ðpb.StateRequest{ - StateId: []byte("justified"), - }) - require.NoError(t, err) - assert.DeepEqual(t, bytesutil.PadTo([]byte("previous"), 32), resp.Data.PreviousJustified.Root) - assert.Equal(t, types.Epoch(113), resp.Data.PreviousJustified.Epoch) - assert.DeepEqual(t, bytesutil.PadTo([]byte("current"), 32), resp.Data.CurrentJustified.Root) - assert.Equal(t, types.Epoch(123), resp.Data.CurrentJustified.Epoch) - assert.DeepEqual(t, bytesutil.PadTo([]byte("finalized"), 32), resp.Data.Finalized.Root) - assert.Equal(t, types.Epoch(103), resp.Data.Finalized.Epoch) - }) - - t.Run("Hex root", func(t *testing.T) { - stateId, err := hexutil.Decode("0x" + strings.Repeat("0", 63) + "1") - require.NoError(t, err) - stateGen := stategen.NewMockService() - stateGen.StatesByRoot[bytesutil.ToBytes32(stateId)] = state - - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, - StateGenService: stateGen, - }, - } - - resp, err := s.GetFinalityCheckpoints(ctx, ðpb.StateRequest{ - StateId: stateId, - }) - require.NoError(t, err) - assert.DeepEqual(t, bytesutil.PadTo([]byte("previous"), 32), resp.Data.PreviousJustified.Root) - assert.Equal(t, types.Epoch(113), resp.Data.PreviousJustified.Epoch) - assert.DeepEqual(t, bytesutil.PadTo([]byte("current"), 32), resp.Data.CurrentJustified.Root) - assert.Equal(t, types.Epoch(123), resp.Data.CurrentJustified.Epoch) - assert.DeepEqual(t, bytesutil.PadTo([]byte("finalized"), 32), resp.Data.Finalized.Root) - assert.Equal(t, types.Epoch(103), resp.Data.Finalized.Epoch) - }) - - t.Run("Hex root not found", func(t *testing.T) { - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, - }, - } - stateId, err := hexutil.Decode("0x" + strings.Repeat("f", 64)) - require.NoError(t, err) - _, err = s.GetFinalityCheckpoints(ctx, ðpb.StateRequest{ - StateId: stateId, - }) - require.ErrorContains(t, "state not found in the last 8192 state roots", err) - }) - - t.Run("Slot", func(t *testing.T) { - stateGen := stategen.NewMockService() - stateGen.StatesBySlot[headSlot] = state - - s := Server{ - StateFetcher: statefetcher.StateProvider{ - GenesisTimeFetcher: &chainMock.ChainService{Slot: &headSlot}, - StateGenService: stateGen, - }, - } - - resp, err := s.GetFinalityCheckpoints(ctx, ðpb.StateRequest{ - StateId: []byte(strconv.FormatUint(uint64(headSlot), 10)), - }) - require.NoError(t, err) - assert.DeepEqual(t, bytesutil.PadTo([]byte("previous"), 32), resp.Data.PreviousJustified.Root) - assert.Equal(t, types.Epoch(113), resp.Data.PreviousJustified.Epoch) - assert.DeepEqual(t, bytesutil.PadTo([]byte("current"), 32), resp.Data.CurrentJustified.Root) - assert.Equal(t, types.Epoch(123), resp.Data.CurrentJustified.Epoch) - assert.DeepEqual(t, bytesutil.PadTo([]byte("finalized"), 32), resp.Data.Finalized.Root) - assert.Equal(t, types.Epoch(103), resp.Data.Finalized.Epoch) - }) - - t.Run("Slot too big", func(t *testing.T) { - s := Server{ - StateFetcher: statefetcher.StateProvider{ - GenesisTimeFetcher: &chainMock.ChainService{ - Genesis: time.Now(), - }, - }, - } - _, err := s.GetFinalityCheckpoints(ctx, ðpb.StateRequest{ - StateId: []byte(strconv.FormatUint(1, 10)), - }) - assert.ErrorContains(t, "slot cannot be in the future", err) - }) - - t.Run("Checkpoints not available", func(t *testing.T) { - st, err := testutil.NewBeaconState() - require.NoError(t, err) - err = st.SetPreviousJustifiedCheckpoint(nil) - require.NoError(t, err) - err = st.SetCurrentJustifiedCheckpoint(nil) - require.NoError(t, err) - err = st.SetFinalizedCheckpoint(nil) - require.NoError(t, err) - - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: st}, - }, - } - - resp, err := s.GetFinalityCheckpoints(ctx, ðpb.StateRequest{ - StateId: []byte("head"), - }) - require.NoError(t, err) - assert.DeepEqual(t, params.BeaconConfig().ZeroHash[:], resp.Data.PreviousJustified.Root) - assert.Equal(t, types.Epoch(0), resp.Data.PreviousJustified.Epoch) - assert.DeepEqual(t, params.BeaconConfig().ZeroHash[:], resp.Data.CurrentJustified.Root) - assert.Equal(t, types.Epoch(0), resp.Data.CurrentJustified.Epoch) - assert.DeepEqual(t, params.BeaconConfig().ZeroHash[:], resp.Data.Finalized.Root) - assert.Equal(t, types.Epoch(0), resp.Data.Finalized.Epoch) - }) + server := &Server{ + StateFetcher: &testutil.MockFetcher{ + BeaconState: fakeState, + }, + } - t.Run("Invalid state", func(t *testing.T) { - s := Server{} - _, err := s.GetFinalityCheckpoints(ctx, ðpb.StateRequest{ - StateId: []byte("foo"), - }) - require.ErrorContains(t, "invalid state ID", err) + resp, err := server.GetFinalityCheckpoints(context.Background(), ðpb.StateRequest{ + StateId: make([]byte, 0), }) + require.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, fakeState.FinalizedCheckpoint().Epoch, resp.Data.Finalized.Epoch) + assert.DeepEqual(t, fakeState.FinalizedCheckpoint().Root, resp.Data.Finalized.Root) + assert.Equal(t, fakeState.CurrentJustifiedCheckpoint().Epoch, resp.Data.CurrentJustified.Epoch) + assert.DeepEqual(t, fakeState.CurrentJustifiedCheckpoint().Root, resp.Data.CurrentJustified.Root) + assert.Equal(t, fakeState.PreviousJustifiedCheckpoint().Epoch, resp.Data.PreviousJustified.Epoch) + assert.DeepEqual(t, fakeState.PreviousJustifiedCheckpoint().Root, resp.Data.PreviousJustified.Root) } diff --git a/beacon-chain/rpc/beaconv1/validator.go b/beacon-chain/rpc/beaconv1/validator.go index 5e86d819d604..6dfd7673cdf2 100644 --- a/beacon-chain/rpc/beaconv1/validator.go +++ b/beacon-chain/rpc/beaconv1/validator.go @@ -9,7 +9,9 @@ import ( types "github.com/prysmaticlabs/eth2-types" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1" "github.com/prysmaticlabs/prysm/beacon-chain/core/helpers" + "github.com/prysmaticlabs/prysm/beacon-chain/rpc/statefetcher" iface "github.com/prysmaticlabs/prysm/beacon-chain/state/interface" + "github.com/prysmaticlabs/prysm/beacon-chain/state/stateV0" "github.com/prysmaticlabs/prysm/proto/migration" "github.com/prysmaticlabs/prysm/shared/bytesutil" "github.com/prysmaticlabs/prysm/shared/params" @@ -17,18 +19,57 @@ import ( "google.golang.org/grpc/status" ) +// validatorNotFoundError represents an error scenario where a validator could not be found. +type validatorNotFoundError struct { + message string +} + +// newValidatorNotFoundError creates a new error instance. +func newValidatorNotFoundError(validatorId []byte) validatorNotFoundError { + return validatorNotFoundError{ + message: fmt.Sprintf("could not find validator with public key '%#x'", validatorId), + } +} + +// Error returns the underlying error message. +func (e *validatorNotFoundError) Error() string { + return e.message +} + +// invalidValidatorIdError represents an error scenario where a validator's ID is invalid. +type invalidValidatorIdError struct { + message string +} + +// newValidatorNotFoundError creates a new error instance. +func newInvalidValidatorIdError(validatorId []byte, reason error) invalidValidatorIdError { + return invalidValidatorIdError{ + message: fmt.Errorf("could not decode validator id '%s': %w", string(validatorId), reason).Error(), + } +} + +// Error returns the underlying error message. +func (e *invalidValidatorIdError) Error() string { + return e.message +} + // GetValidator returns a validator specified by state and id or public key along with status and balance. func (bs *Server) GetValidator(ctx context.Context, req *ethpb.StateValidatorRequest) (*ethpb.StateValidatorResponse, error) { state, err := bs.StateFetcher.State(ctx, req.StateId) if err != nil { - return nil, status.Errorf(codes.Internal, "Could not get state: %v", err) + if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { + return nil, status.Errorf(codes.NotFound, "could not get state: %v", stateNotFoundErr) + } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { + return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) + } + return nil, status.Errorf(codes.Internal, "State not found: %v", err) } if len(req.ValidatorId) == 0 { - return nil, status.Error(codes.Internal, "Must request a validator id") + return nil, status.Error(codes.InvalidArgument, "Validator ID is required") } valContainer, err := valContainersByRequestIds(state, [][]byte{req.ValidatorId}) if err != nil { - return nil, status.Errorf(codes.Internal, "Could not get validator container: %v", err) + return nil, handleValContainerErr(err) } if len(valContainer) == 0 { return nil, status.Error(codes.NotFound, "Could not find validator") @@ -40,12 +81,17 @@ func (bs *Server) GetValidator(ctx context.Context, req *ethpb.StateValidatorReq func (bs *Server) ListValidators(ctx context.Context, req *ethpb.StateValidatorsRequest) (*ethpb.StateValidatorsResponse, error) { state, err := bs.StateFetcher.State(ctx, req.StateId) if err != nil { + if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { + return nil, status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) + } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { + return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) + } return nil, status.Errorf(codes.Internal, "Could not get state: %v", err) } valContainers, err := valContainersByRequestIds(state, req.Id) if err != nil { - return nil, status.Errorf(codes.Internal, "Could not get validator container: %v", err) + return nil, handleValContainerErr(err) } if len(req.Status) == 0 { @@ -53,7 +99,11 @@ func (bs *Server) ListValidators(ctx context.Context, req *ethpb.StateValidators } filterStatus := make(map[ethpb.ValidatorStatus]bool, len(req.Status)) + const lastValidStatusValue = ethpb.ValidatorStatus(12) for _, ss := range req.Status { + if ss > lastValidStatusValue { + return nil, status.Errorf(codes.InvalidArgument, "Invalid status "+ss.String()) + } filterStatus[ss] = true } epoch := helpers.SlotToEpoch(state.Slot()) @@ -78,12 +128,17 @@ func (bs *Server) ListValidators(ctx context.Context, req *ethpb.StateValidators func (bs *Server) ListValidatorBalances(ctx context.Context, req *ethpb.ValidatorBalancesRequest) (*ethpb.ValidatorBalancesResponse, error) { state, err := bs.StateFetcher.State(ctx, req.StateId) if err != nil { + if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { + return nil, status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) + } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { + return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) + } return nil, status.Errorf(codes.Internal, "Could not get state: %v", err) } valContainers, err := valContainersByRequestIds(state, req.Id) if err != nil { - return nil, status.Errorf(codes.Internal, "Could not get validator: %v", err) + return nil, handleValContainerErr(err) } valBalances := make([]*ethpb.ValidatorBalance, len(valContainers)) for i := 0; i < len(valContainers); i++ { @@ -100,6 +155,11 @@ func (bs *Server) ListValidatorBalances(ctx context.Context, req *ethpb.Validato func (bs *Server) ListCommittees(ctx context.Context, req *ethpb.StateCommitteesRequest) (*ethpb.StateCommitteesResponse, error) { state, err := bs.StateFetcher.State(ctx, req.StateId) if err != nil { + if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { + return nil, status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) + } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { + return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) + } return nil, status.Errorf(codes.Internal, "Could not get state: %v", err) } @@ -114,11 +174,11 @@ func (bs *Server) ListCommittees(ctx context.Context, req *ethpb.StateCommittees startSlot, err := helpers.StartSlot(epoch) if err != nil { - return nil, status.Errorf(codes.Internal, "Could not get epoch start slot: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "Invalid epoch: %v", err) } endSlot, err := helpers.EndSlot(epoch) if err != nil { - return nil, status.Errorf(codes.Internal, "Could not get epoch end slot: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "Invalid epoch: %v", err) } committeesPerSlot := helpers.SlotCommitteeCount(activeCount) committees := make([]*ethpb.Committee, 0) @@ -149,10 +209,10 @@ func (bs *Server) ListCommittees(ctx context.Context, req *ethpb.StateCommittees // or its index. func valContainersByRequestIds(state iface.BeaconState, validatorIds [][]byte) ([]*ethpb.ValidatorContainer, error) { epoch := helpers.SlotToEpoch(state.Slot()) - allValidators := state.Validators() - allBalances := state.Balances() var valContainers []*ethpb.ValidatorContainer if len(validatorIds) == 0 { + allValidators := state.Validators() + allBalances := state.Balances() valContainers = make([]*ethpb.ValidatorContainer, len(allValidators)) for i, validator := range allValidators { v1Validator := migration.V1Alpha1ValidatorToV1(validator) @@ -175,28 +235,38 @@ func valContainersByRequestIds(state iface.BeaconState, validatorIds [][]byte) ( var ok bool valIndex, ok = state.ValidatorIndexByPubkey(bytesutil.ToBytes48(validatorId)) if !ok { - return nil, fmt.Errorf("could not find validator with public key: %#x", validatorId) + e := newValidatorNotFoundError(validatorId) + return nil, &e } } else { index, err := strconv.ParseUint(string(validatorId), 10, 64) if err != nil { - return nil, errors.Wrap(err, "could not decode validator id") + e := newInvalidValidatorIdError(validatorId, err) + return nil, &e } valIndex = types.ValidatorIndex(index) } - v1Validator := migration.V1Alpha1ValidatorToV1(allValidators[valIndex]) + validator, err := state.ValidatorAtIndex(valIndex) + if outOfRangeErr, ok := err.(*stateV0.IndexOutOfRangeError); ok { + return nil, outOfRangeErr + } + if err != nil { + return nil, fmt.Errorf("could not get validator: %w", err) + } + v1Validator := migration.V1Alpha1ValidatorToV1(validator) subStatus, err := validatorSubStatus(v1Validator, epoch) if err != nil { return nil, fmt.Errorf("could not get validator sub status: %v", err) } valContainers[i] = ðpb.ValidatorContainer{ Index: valIndex, - Balance: allBalances[valIndex], + Balance: v1Validator.EffectiveBalance, Status: subStatus, Validator: v1Validator, } } } + return valContainers, nil } @@ -260,3 +330,16 @@ func validatorSubStatus(validator *ethpb.Validator, epoch types.Epoch) (ethpb.Va return 0, errors.New("invalid validator state") } + +func handleValContainerErr(err error) error { + if outOfRangeErr, ok := err.(*stateV0.IndexOutOfRangeError); ok { + return status.Errorf(codes.InvalidArgument, "Invalid validator ID: %v", outOfRangeErr) + } + if invalidIdErr, ok := err.(*invalidValidatorIdError); ok { + return status.Errorf(codes.InvalidArgument, "Invalid validator ID: %v", invalidIdErr) + } + if notFoundErr, ok := err.(*validatorNotFoundError); ok { + return status.Errorf(codes.NotFound, "Validator not found: %v", notFoundErr) + } + return status.Errorf(codes.Internal, "Could not get validator container: %v", err) +} diff --git a/beacon-chain/rpc/beaconv1/validator_test.go b/beacon-chain/rpc/beaconv1/validator_test.go index 9734e117a63b..926bcef17d23 100644 --- a/beacon-chain/rpc/beaconv1/validator_test.go +++ b/beacon-chain/rpc/beaconv1/validator_test.go @@ -3,20 +3,20 @@ package beaconv1 import ( "bytes" "context" - "strings" + "strconv" "testing" ethpb_alpha "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1" + chainMock "github.com/prysmaticlabs/prysm/beacon-chain/blockchain/testing" + "github.com/prysmaticlabs/prysm/beacon-chain/rpc/statefetcher" - "github.com/ethereum/go-ethereum/common/hexutil" types "github.com/prysmaticlabs/eth2-types" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1" - chainMock "github.com/prysmaticlabs/prysm/beacon-chain/blockchain/testing" "github.com/prysmaticlabs/prysm/beacon-chain/core/helpers" - "github.com/prysmaticlabs/prysm/beacon-chain/rpc/statefetcher" + "github.com/prysmaticlabs/prysm/beacon-chain/rpc/testutil" iface "github.com/prysmaticlabs/prysm/beacon-chain/state/interface" "github.com/prysmaticlabs/prysm/shared/params" - "github.com/prysmaticlabs/prysm/shared/testutil" + sharedtestutil "github.com/prysmaticlabs/prysm/shared/testutil" "github.com/prysmaticlabs/prysm/shared/testutil/assert" "github.com/prysmaticlabs/prysm/shared/testutil/require" ) @@ -25,12 +25,12 @@ func TestGetValidator(t *testing.T) { ctx := context.Background() var state iface.BeaconState - state, _ = testutil.DeterministicGenesisState(t, 8192) + state, _ = sharedtestutil.DeterministicGenesisState(t, 8192) t.Run("Head Get Validator by index", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } @@ -44,8 +44,8 @@ func TestGetValidator(t *testing.T) { t.Run("Head Get Validator by pubkey", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } @@ -59,40 +59,16 @@ func TestGetValidator(t *testing.T) { assert.Equal(t, true, bytes.Equal(pubKey[:], resp.Data.Validator.Pubkey)) }) - t.Run("Hex root not found", func(t *testing.T) { - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, - }, - } - stateId, err := hexutil.Decode("0x" + strings.Repeat("f", 64)) - require.NoError(t, err) - _, err = s.GetValidator(ctx, ðpb.StateValidatorRequest{ - StateId: stateId, - }) - require.ErrorContains(t, "state not found in the last 8192 state roots", err) - }) - - t.Run("Invalid state ID", func(t *testing.T) { - s := Server{} - pubKey := state.PubkeyAtIndex(types.ValidatorIndex(20)) - _, err := s.GetValidator(ctx, ðpb.StateValidatorRequest{ - StateId: []byte("foo"), - ValidatorId: pubKey[:], - }) - require.ErrorContains(t, "invalid state ID", err) - }) - t.Run("Validator ID required", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } _, err := s.GetValidator(ctx, ðpb.StateValidatorRequest{ StateId: []byte("head"), }) - require.ErrorContains(t, "Must request a validator id", err) + require.ErrorContains(t, "Validator ID is required", err) }) } @@ -100,12 +76,12 @@ func TestListValidators(t *testing.T) { ctx := context.Background() var state iface.BeaconState - state, _ = testutil.DeterministicGenesisState(t, 8192) + state, _ = sharedtestutil.DeterministicGenesisState(t, 8192) t.Run("Head List All Validators", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } @@ -121,8 +97,8 @@ func TestListValidators(t *testing.T) { t.Run("Head List Validators by index", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } @@ -141,8 +117,8 @@ func TestListValidators(t *testing.T) { t.Run("Head List Validators by pubkey", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } idNums := []types.ValidatorIndex{20, 66, 90, 100} @@ -165,8 +141,8 @@ func TestListValidators(t *testing.T) { t.Run("Head List Validators by both index and pubkey", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } @@ -188,35 +164,13 @@ func TestListValidators(t *testing.T) { assert.Equal(t, ethpb.ValidatorStatus_ACTIVE_ONGOING, val.Status) } }) - - t.Run("Hex root not found", func(t *testing.T) { - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, - }, - } - stateId, err := hexutil.Decode("0x" + strings.Repeat("f", 64)) - require.NoError(t, err) - _, err = s.ListValidators(ctx, ðpb.StateValidatorsRequest{ - StateId: stateId, - }) - require.ErrorContains(t, "state not found in the last 8192 state roots", err) - }) - - t.Run("Invalid state ID", func(t *testing.T) { - s := Server{} - _, err := s.ListValidators(ctx, ðpb.StateValidatorsRequest{ - StateId: []byte("foo"), - }) - require.ErrorContains(t, "invalid state ID", err) - }) } func TestListValidators_Status(t *testing.T) { ctx := context.Background() var state iface.BeaconState - state, _ = testutil.DeterministicGenesisState(t, 8192) + state, _ = sharedtestutil.DeterministicGenesisState(t, 8192) farFutureEpoch := params.BeaconConfig().FarFutureEpoch validators := []*ethpb_alpha.Validator{ @@ -285,7 +239,7 @@ func TestListValidators_Status(t *testing.T) { t.Run("Head List All ACTIVE Validators", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ + StateFetcher: &statefetcher.StateProvider{ ChainInfoFetcher: &chainMock.ChainService{State: state}, }, } @@ -316,7 +270,7 @@ func TestListValidators_Status(t *testing.T) { t.Run("Head List All ACTIVE_ONGOING Validators", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ + StateFetcher: &statefetcher.StateProvider{ ChainInfoFetcher: &chainMock.ChainService{State: state}, }, } @@ -346,7 +300,7 @@ func TestListValidators_Status(t *testing.T) { require.NoError(t, state.SetSlot(params.BeaconConfig().SlotsPerEpoch*35)) t.Run("Head List All EXITED Validators", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ + StateFetcher: &statefetcher.StateProvider{ ChainInfoFetcher: &chainMock.ChainService{State: state}, }, } @@ -375,7 +329,7 @@ func TestListValidators_Status(t *testing.T) { t.Run("Head List All PENDING_INITIALIZED and EXITED_UNSLASHED Validators", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ + StateFetcher: &statefetcher.StateProvider{ ChainInfoFetcher: &chainMock.ChainService{State: state}, }, } @@ -404,7 +358,7 @@ func TestListValidators_Status(t *testing.T) { t.Run("Head List All PENDING and EXITED Validators", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ + StateFetcher: &statefetcher.StateProvider{ ChainInfoFetcher: &chainMock.ChainService{State: state}, }, } @@ -437,12 +391,12 @@ func TestListValidatorBalances(t *testing.T) { ctx := context.Background() var state iface.BeaconState - state, _ = testutil.DeterministicGenesisState(t, 8192) + state, _ = sharedtestutil.DeterministicGenesisState(t, 8192) t.Run("Head List Validators Balance by index", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } @@ -461,8 +415,8 @@ func TestListValidatorBalances(t *testing.T) { t.Run("Head List Validators Balance by pubkey", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } idNums := []types.ValidatorIndex{20, 66, 90, 100} @@ -484,8 +438,8 @@ func TestListValidatorBalances(t *testing.T) { t.Run("Head List Validators Balance by both index and pubkey", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } @@ -503,41 +457,19 @@ func TestListValidatorBalances(t *testing.T) { assert.Equal(t, params.BeaconConfig().MaxEffectiveBalance, val.Balance) } }) - - t.Run("Hex root not found", func(t *testing.T) { - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, - }, - } - stateId, err := hexutil.Decode("0x" + strings.Repeat("f", 64)) - require.NoError(t, err) - _, err = s.ListValidatorBalances(ctx, ðpb.ValidatorBalancesRequest{ - StateId: stateId, - }) - require.ErrorContains(t, "state not found in the last 8192 state roots", err) - }) - - t.Run("Invalid state ID", func(t *testing.T) { - s := Server{} - _, err := s.ListValidatorBalances(ctx, ðpb.ValidatorBalancesRequest{ - StateId: []byte("foo"), - }) - require.ErrorContains(t, "invalid state ID", err) - }) } func TestListCommittees(t *testing.T) { ctx := context.Background() var state iface.BeaconState - state, _ = testutil.DeterministicGenesisState(t, 8192) + state, _ = sharedtestutil.DeterministicGenesisState(t, 8192) epoch := helpers.SlotToEpoch(state.Slot()) t.Run("Head All Committees", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } @@ -554,8 +486,8 @@ func TestListCommittees(t *testing.T) { t.Run("Head All Committees of Epoch 10", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } epoch := types.Epoch(10) @@ -571,8 +503,8 @@ func TestListCommittees(t *testing.T) { t.Run("Head All Committees of Slot 4", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } @@ -594,8 +526,8 @@ func TestListCommittees(t *testing.T) { t.Run("Head All Committees of Index 1", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } @@ -617,8 +549,8 @@ func TestListCommittees(t *testing.T) { t.Run("Head All Committees of Slot 2, Index 1", func(t *testing.T) { s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, + StateFetcher: &testutil.MockFetcher{ + BeaconState: state, }, } @@ -637,28 +569,6 @@ func TestListCommittees(t *testing.T) { assert.Equal(t, index, datum.Index) } }) - - t.Run("Hex root not found", func(t *testing.T) { - s := Server{ - StateFetcher: statefetcher.StateProvider{ - ChainInfoFetcher: &chainMock.ChainService{State: state}, - }, - } - stateId, err := hexutil.Decode("0x" + strings.Repeat("f", 64)) - require.NoError(t, err) - _, err = s.ListCommittees(ctx, ðpb.StateCommitteesRequest{ - StateId: stateId, - }) - require.ErrorContains(t, "state not found in the last 8192 state roots", err) - }) - - t.Run("Invalid state ID", func(t *testing.T) { - s := Server{} - _, err := s.ListCommittees(ctx, ðpb.StateCommitteesRequest{ - StateId: []byte("foo"), - }) - require.ErrorContains(t, "invalid state ID", err) - }) } func Test_validatorStatus(t *testing.T) { @@ -932,3 +842,15 @@ func Test_validatorSubStatus(t *testing.T) { }) } } + +// This test verifies how many validator statuses have meaningful values. +// The first expected non-meaningful value will have x.String() equal to its numeric representation. +// This test assumes we start numbering from 0 and do not skip any values. +// Having a test like this allows us to use e.g. `if value < 10` for validity checks. +func TestNumberOfStatuses(t *testing.T) { + lastValidEnumValue := 12 + x := ethpb.ValidatorStatus(lastValidEnumValue) + assert.NotEqual(t, strconv.Itoa(lastValidEnumValue), x.String()) + x = ethpb.ValidatorStatus(lastValidEnumValue + 1) + assert.Equal(t, strconv.Itoa(lastValidEnumValue+1), x.String()) +} diff --git a/beacon-chain/rpc/debugv1/BUILD.bazel b/beacon-chain/rpc/debugv1/BUILD.bazel index dcf10b821e21..6a3cce518f18 100644 --- a/beacon-chain/rpc/debugv1/BUILD.bazel +++ b/beacon-chain/rpc/debugv1/BUILD.bazel @@ -13,7 +13,6 @@ go_library( "//beacon-chain/blockchain:go_default_library", "//beacon-chain/db:go_default_library", "//beacon-chain/rpc/statefetcher:go_default_library", - "@com_github_pkg_errors//:go_default_library", "@com_github_prysmaticlabs_ethereumapis//eth/v1:go_default_library", "@io_opencensus_go//trace:go_default_library", "@org_golang_google_grpc//codes:go_default_library", diff --git a/beacon-chain/rpc/debugv1/debug.go b/beacon-chain/rpc/debugv1/debug.go index 93af82d5cf11..5c28c834b789 100644 --- a/beacon-chain/rpc/debugv1/debug.go +++ b/beacon-chain/rpc/debugv1/debug.go @@ -3,7 +3,6 @@ package debugv1 import ( "context" - "github.com/pkg/errors" ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1" "github.com/prysmaticlabs/prysm/beacon-chain/rpc/statefetcher" "go.opencensus.io/trace" @@ -20,16 +19,16 @@ func (ds *Server) GetBeaconState(ctx context.Context, req *ethpb.StateRequest) ( state, err := ds.StateFetcher.State(ctx, req.StateId) if err != nil { if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { - return nil, status.Errorf(codes.NotFound, "could not get state: %v", stateNotFoundErr) - } else if errors.Is(err, statefetcher.ErrInvalidStateId) { - return nil, status.Errorf(codes.InvalidArgument, "could not get state: %v", err) + return nil, status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) + } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { + return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) } - return nil, status.Errorf(codes.Internal, "could not get state: %v", err) + return nil, status.Errorf(codes.Internal, "Invalid state ID: %v", err) } protoState, err := state.ToProto() if err != nil { - return nil, status.Errorf(codes.Internal, "could not convert state to proto: %v", err) + return nil, status.Errorf(codes.Internal, "Could not convert state to proto: %v", err) } return ðpb.BeaconStateResponse{ @@ -44,12 +43,17 @@ func (ds *Server) GetBeaconStateSsz(ctx context.Context, req *ethpb.StateRequest state, err := ds.StateFetcher.State(ctx, req.StateId) if err != nil { - return nil, status.Errorf(codes.Internal, "could not get state: %v", err) + if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { + return nil, status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) + } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { + return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) + } + return nil, status.Errorf(codes.Internal, "Invalid state ID: %v", err) } sszState, err := state.MarshalSSZ() if err != nil { - return nil, status.Errorf(codes.Internal, "could not marshal state into SSZ: %v", err) + return nil, status.Errorf(codes.Internal, "Could not marshal state into SSZ: %v", err) } return ðpb.BeaconStateSszResponse{Data: sszState}, nil diff --git a/beacon-chain/rpc/nodev1/node.go b/beacon-chain/rpc/nodev1/node.go index cf061ce30f50..2c6c6ab95442 100644 --- a/beacon-chain/rpc/nodev1/node.go +++ b/beacon-chain/rpc/nodev1/node.go @@ -44,7 +44,7 @@ func (ns *Server) GetIdentity(ctx context.Context, _ *emptypb.Empty) (*ethpb.Ide serializedEnr, err := p2p.SerializeENR(ns.PeerManager.ENR()) if err != nil { - return nil, status.Errorf(codes.Internal, "could not obtain enr: %v", err) + return nil, status.Errorf(codes.Internal, "Could not obtain enr: %v", err) } enr := "enr:" + serializedEnr @@ -56,7 +56,7 @@ func (ns *Server) GetIdentity(ctx context.Context, _ *emptypb.Empty) (*ethpb.Ide sourceDisc, err := ns.PeerManager.DiscoveryAddresses() if err != nil { - return nil, status.Errorf(codes.Internal, "could not obtain discovery address: %v", err) + return nil, status.Errorf(codes.Internal, "Could not obtain discovery address: %v", err) } discoveryAddresses := make([]string, len(sourceDisc)) for i := range sourceDisc { @@ -87,7 +87,7 @@ func (ns *Server) GetPeer(ctx context.Context, req *ethpb.PeerRequest) (*ethpb.P peerStatus := ns.PeersFetcher.Peers() id, err := peer.Decode(req.PeerId) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "Could not decode peer ID: %v", err) + return nil, status.Errorf(codes.InvalidArgument, "Invalid peer ID: %v", err) } enr, err := peerStatus.ENR(id) if err != nil { @@ -357,7 +357,7 @@ func peerInfo(peerStatus *peers.Status, id peer.ID) (*ethpb.Peer, error) { v1ConnState := migration.V1Alpha1ConnectionStateToV1(ethpb_alpha.ConnectionState(connectionState)) v1PeerDirection, err := migration.V1Alpha1PeerDirectionToV1(ethpb_alpha.PeerDirection(direction)) if err != nil { - return nil, status.Errorf(codes.Internal, "Could not handle peer direction: %v", err) + return nil, fmt.Errorf("could not handle peer direction: %w", err) } p := ethpb.Peer{ PeerId: id.Pretty(), diff --git a/beacon-chain/rpc/nodev1/node_test.go b/beacon-chain/rpc/nodev1/node_test.go index d77792bce71e..ebc7cec57d78 100644 --- a/beacon-chain/rpc/nodev1/node_test.go +++ b/beacon-chain/rpc/nodev1/node_test.go @@ -138,7 +138,7 @@ func TestGetIdentity(t *testing.T) { } _, err = s.GetIdentity(ctx, &emptypb.Empty{}) - assert.ErrorContains(t, "could not obtain enr", err) + assert.ErrorContains(t, "Could not obtain enr", err) }) t.Run("Discovery addresses failure", func(t *testing.T) { @@ -155,7 +155,7 @@ func TestGetIdentity(t *testing.T) { } _, err = s.GetIdentity(ctx, &emptypb.Empty{}) - assert.ErrorContains(t, "could not obtain discovery address", err) + assert.ErrorContains(t, "Could not obtain discovery address", err) }) } @@ -212,7 +212,7 @@ func TestGetPeer(t *testing.T) { t.Run("Invalid ID", func(t *testing.T) { _, err = s.GetPeer(ctx, ðpb.PeerRequest{PeerId: "foo"}) - assert.ErrorContains(t, "Could not decode peer ID", err) + assert.ErrorContains(t, "Invalid peer ID", err) }) t.Run("Peer not found", func(t *testing.T) { diff --git a/beacon-chain/rpc/service.go b/beacon-chain/rpc/service.go index 20fd5848afc6..e4f166236e41 100644 --- a/beacon-chain/rpc/service.go +++ b/beacon-chain/rpc/service.go @@ -247,7 +247,7 @@ func (s *Service) Start() { Broadcaster: s.cfg.Broadcaster, BlockReceiver: s.cfg.BlockReceiver, StateGenService: s.cfg.StateGen, - StateFetcher: statefetcher.StateProvider{ + StateFetcher: &statefetcher.StateProvider{ BeaconDB: s.cfg.BeaconDB, ChainInfoFetcher: s.cfg.ChainInfoFetcher, GenesisTimeFetcher: s.cfg.GenesisTimeFetcher, diff --git a/beacon-chain/rpc/statefetcher/BUILD.bazel b/beacon-chain/rpc/statefetcher/BUILD.bazel index 9222f6dd7a73..c4aceb5d063e 100644 --- a/beacon-chain/rpc/statefetcher/BUILD.bazel +++ b/beacon-chain/rpc/statefetcher/BUILD.bazel @@ -8,6 +8,7 @@ go_library( visibility = ["//beacon-chain:__subpackages__"], deps = [ "//beacon-chain/blockchain:go_default_library", + "//beacon-chain/core/helpers:go_default_library", "//beacon-chain/db:go_default_library", "//beacon-chain/state/interface:go_default_library", "//beacon-chain/state/stategen:go_default_library", diff --git a/beacon-chain/rpc/statefetcher/fetcher.go b/beacon-chain/rpc/statefetcher/fetcher.go index 77a16dde0551..6fa4bcd04afe 100644 --- a/beacon-chain/rpc/statefetcher/fetcher.go +++ b/beacon-chain/rpc/statefetcher/fetcher.go @@ -10,14 +10,29 @@ import ( "github.com/pkg/errors" types "github.com/prysmaticlabs/eth2-types" "github.com/prysmaticlabs/prysm/beacon-chain/blockchain" + "github.com/prysmaticlabs/prysm/beacon-chain/core/helpers" "github.com/prysmaticlabs/prysm/beacon-chain/db" iface "github.com/prysmaticlabs/prysm/beacon-chain/state/interface" "github.com/prysmaticlabs/prysm/beacon-chain/state/stategen" "github.com/prysmaticlabs/prysm/shared/bytesutil" ) -// ErrInvalidStateId represents an error scenario where a state ID is invalid. -var ErrInvalidStateId = errors.New("invalid state ID") +// StateIdParseError represents an error scenario where a state ID could not be parsed. +type StateIdParseError struct { + message string +} + +// NewStateIdParseError creates a new error instance. +func NewStateIdParseError(reason error) StateIdParseError { + return StateIdParseError{ + message: fmt.Sprintf("could not parse state ID: %v", reason), + } +} + +// Error returns the underlying error message. +func (e *StateIdParseError) Error() string { + return e.message +} // StateNotFoundError represents an error scenario where a state could not be found. type StateNotFoundError struct { @@ -36,9 +51,27 @@ func (e *StateNotFoundError) Error() string { return e.message } -// Fetcher is responsible for retrieving the BeaconState. +// StateRootNotFoundError represents an error scenario where a state root could not be found. +type StateRootNotFoundError struct { + message string +} + +// NewStateRootNotFoundError creates a new error instance. +func NewStateRootNotFoundError(stateRootsSize int) StateNotFoundError { + return StateNotFoundError{ + message: fmt.Sprintf("state root not found in the last %d state roots", stateRootsSize), + } +} + +// Error returns the underlying error message. +func (e *StateRootNotFoundError) Error() string { + return e.message +} + +// Fetcher is responsible for retrieving info related with the beacon chain. type Fetcher interface { State(ctx context.Context, stateId []byte) (iface.BeaconState, error) + StateRoot(ctx context.Context, stateId []byte) ([]byte, error) } // StateProvider is a real implementation of Fetcher. @@ -55,7 +88,7 @@ type StateProvider struct { // - "finalized" // - "justified" // - -// - +// - func (p *StateProvider) State(ctx context.Context, stateId []byte) (iface.BeaconState, error) { var ( s iface.BeaconState @@ -93,7 +126,8 @@ func (p *StateProvider) State(ctx context.Context, stateId []byte) (iface.Beacon slotNumber, parseErr := strconv.ParseUint(stateIdString, 10, 64) if parseErr != nil { // ID format does not match any valid options. - return nil, ErrInvalidStateId + e := NewStateIdParseError(parseErr) + return nil, &e } s, err = p.stateBySlot(ctx, types.Slot(slotNumber)) } @@ -102,6 +136,46 @@ func (p *StateProvider) State(ctx context.Context, stateId []byte) (iface.Beacon return s, err } +// StateRoot returns a beacon state root for a given identifier. The identifier can be one of: +// - "head" (canonical head in node's view) +// - "genesis" +// - "finalized" +// - "justified" +// - +// - +func (p *StateProvider) StateRoot(ctx context.Context, stateId []byte) ([]byte, error) { + var ( + root []byte + err error + ) + + stateIdString := strings.ToLower(string(stateId)) + switch stateIdString { + case "head": + root, err = p.headStateRoot(ctx) + case "genesis": + root, err = p.genesisStateRoot(ctx) + case "finalized": + root, err = p.finalizedStateRoot(ctx) + case "justified": + root, err = p.justifiedStateRoot(ctx) + default: + if len(stateId) == 32 { + root, err = p.stateRootByHex(ctx, stateId) + } else { + slotNumber, parseErr := strconv.ParseUint(stateIdString, 10, 64) + if parseErr != nil { + e := NewStateIdParseError(parseErr) + // ID format does not match any valid options. + return nil, &e + } + root, err = p.stateRootBySlot(ctx, types.Slot(slotNumber)) + } + } + + return root, err +} + func (p *StateProvider) stateByHex(ctx context.Context, stateId []byte) (iface.BeaconState, error) { headState, err := p.ChainInfoFetcher.HeadState(ctx) if err != nil { @@ -129,3 +203,93 @@ func (p *StateProvider) stateBySlot(ctx context.Context, slot types.Slot) (iface } return state, nil } + +func (p *StateProvider) headStateRoot(ctx context.Context) ([]byte, error) { + b, err := p.ChainInfoFetcher.HeadBlock(ctx) + if err != nil { + return nil, errors.Wrap(err, "could not get head block") + } + if err := helpers.VerifyNilBeaconBlock(b); err != nil { + return nil, err + } + return b.Block().StateRoot(), nil +} + +func (p *StateProvider) genesisStateRoot(ctx context.Context) ([]byte, error) { + b, err := p.BeaconDB.GenesisBlock(ctx) + if err != nil { + return nil, errors.Wrap(err, "could not get genesis block") + } + if err := helpers.VerifyNilBeaconBlock(b); err != nil { + return nil, err + } + return b.Block().StateRoot(), nil +} + +func (p *StateProvider) finalizedStateRoot(ctx context.Context) ([]byte, error) { + cp, err := p.BeaconDB.FinalizedCheckpoint(ctx) + if err != nil { + return nil, errors.Wrap(err, "could not get finalized checkpoint") + } + b, err := p.BeaconDB.Block(ctx, bytesutil.ToBytes32(cp.Root)) + if err != nil { + return nil, errors.Wrap(err, "could not get finalized block") + } + if err := helpers.VerifyNilBeaconBlock(b); err != nil { + return nil, err + } + return b.Block().StateRoot(), nil +} + +func (p *StateProvider) justifiedStateRoot(ctx context.Context) ([]byte, error) { + cp, err := p.BeaconDB.JustifiedCheckpoint(ctx) + if err != nil { + return nil, errors.Wrap(err, "could not get justified checkpoint") + } + b, err := p.BeaconDB.Block(ctx, bytesutil.ToBytes32(cp.Root)) + if err != nil { + return nil, errors.Wrap(err, "could not get justified block") + } + if err := helpers.VerifyNilBeaconBlock(b); err != nil { + return nil, err + } + return b.Block().StateRoot(), nil +} + +func (p *StateProvider) stateRootByHex(ctx context.Context, stateId []byte) ([]byte, error) { + var stateRoot [32]byte + copy(stateRoot[:], stateId) + headState, err := p.ChainInfoFetcher.HeadState(ctx) + if err != nil { + return nil, errors.Wrap(err, "could not get head state") + } + for _, root := range headState.StateRoots() { + if bytes.Equal(root, stateRoot[:]) { + return stateRoot[:], nil + } + } + + rootNotFoundErr := NewStateRootNotFoundError(len(headState.StateRoots())) + return nil, &rootNotFoundErr +} + +func (p *StateProvider) stateRootBySlot(ctx context.Context, slot types.Slot) ([]byte, error) { + currentSlot := p.GenesisTimeFetcher.CurrentSlot() + if slot > currentSlot { + return nil, errors.New("slot cannot be in the future") + } + found, blks, err := p.BeaconDB.BlocksBySlot(ctx, slot) + if err != nil { + return nil, errors.Wrap(err, "could not get blocks") + } + if !found { + return nil, errors.New("no block exists") + } + if len(blks) != 1 { + return nil, errors.New("multiple blocks exist in same slot") + } + if blks[0] == nil || blks[0].Block() == nil { + return nil, errors.New("nil block") + } + return blks[0].Block().StateRoot(), nil +} diff --git a/beacon-chain/rpc/statefetcher/fetcher_test.go b/beacon-chain/rpc/statefetcher/fetcher_test.go index 21cbd6dfa349..911f47563139 100644 --- a/beacon-chain/rpc/statefetcher/fetcher_test.go +++ b/beacon-chain/rpc/statefetcher/fetcher_test.go @@ -22,7 +22,7 @@ import ( "github.com/prysmaticlabs/prysm/shared/testutil/require" ) -func TestGetStateRoot(t *testing.T) { +func TestGetState(t *testing.T) { ctx := context.Background() headSlot := types.Slot(123) @@ -180,7 +180,192 @@ func TestGetStateRoot(t *testing.T) { t.Run("Invalid state", func(t *testing.T) { p := StateProvider{} _, err := p.State(ctx, []byte("foo")) - require.ErrorContains(t, "invalid state ID", err) + require.ErrorContains(t, "could not parse state ID", err) + }) +} + +func TestGetStateRoot(t *testing.T) { + ctx := context.Background() + + headSlot := types.Slot(123) + fillSlot := func(state *pb.BeaconState) error { + state.Slot = headSlot + return nil + } + state, err := testutil.NewBeaconState(testutil.FillRootsNaturalOpt, fillSlot) + require.NoError(t, err) + stateRoot, err := state.HashTreeRoot(ctx) + require.NoError(t, err) + + t.Run("Head", func(t *testing.T) { + b := testutil.NewBeaconBlock() + b.Block.StateRoot = stateRoot[:] + p := StateProvider{ + ChainInfoFetcher: &chainMock.ChainService{ + State: state, + Block: interfaces.WrappedPhase0SignedBeaconBlock(b), + }, + } + + s, err := p.StateRoot(ctx, []byte("head")) + require.NoError(t, err) + assert.DeepEqual(t, stateRoot[:], s) + }) + + t.Run("Genesis", func(t *testing.T) { + db := testDB.SetupDB(t) + b := testutil.NewBeaconBlock() + require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(b))) + r, err := b.Block.HashTreeRoot() + require.NoError(t, err) + + state, err := testutil.NewBeaconState(func(state *pb.BeaconState) error { + state.BlockRoots[0] = r[:] + return nil + }) + require.NoError(t, err) + + require.NoError(t, db.SaveStateSummary(ctx, &pb.StateSummary{Root: r[:]})) + require.NoError(t, db.SaveGenesisBlockRoot(ctx, r)) + require.NoError(t, db.SaveState(ctx, state, r)) + + p := StateProvider{ + BeaconDB: db, + } + + s, err := p.StateRoot(ctx, []byte("genesis")) + require.NoError(t, err) + genesisBlock, err := db.GenesisBlock(ctx) + require.NoError(t, err) + assert.DeepEqual(t, genesisBlock.Block().StateRoot(), s) + }) + + t.Run("Finalized", func(t *testing.T) { + db := testDB.SetupDB(t) + genesis := bytesutil.ToBytes32([]byte("genesis")) + require.NoError(t, db.SaveGenesisBlockRoot(ctx, genesis)) + blk := testutil.NewBeaconBlock() + blk.Block.ParentRoot = genesis[:] + blk.Block.Slot = 40 + root, err := blk.Block.HashTreeRoot() + require.NoError(t, err) + cp := ð.Checkpoint{ + Epoch: 5, + Root: root[:], + } + // a valid chain is required to save finalized checkpoint. + require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(blk))) + st, err := testutil.NewBeaconState() + require.NoError(t, err) + require.NoError(t, st.SetSlot(1)) + // a state is required to save checkpoint + require.NoError(t, db.SaveState(ctx, st, root)) + require.NoError(t, db.SaveFinalizedCheckpoint(ctx, cp)) + + p := StateProvider{ + BeaconDB: db, + } + + s, err := p.StateRoot(ctx, []byte("finalized")) + require.NoError(t, err) + assert.DeepEqual(t, blk.Block.StateRoot, s) + }) + + t.Run("Justified", func(t *testing.T) { + db := testDB.SetupDB(t) + genesis := bytesutil.ToBytes32([]byte("genesis")) + require.NoError(t, db.SaveGenesisBlockRoot(ctx, genesis)) + blk := testutil.NewBeaconBlock() + blk.Block.ParentRoot = genesis[:] + blk.Block.Slot = 40 + root, err := blk.Block.HashTreeRoot() + require.NoError(t, err) + cp := ð.Checkpoint{ + Epoch: 5, + Root: root[:], + } + // a valid chain is required to save finalized checkpoint. + require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(blk))) + st, err := testutil.NewBeaconState() + require.NoError(t, err) + require.NoError(t, st.SetSlot(1)) + // a state is required to save checkpoint + require.NoError(t, db.SaveState(ctx, st, root)) + require.NoError(t, db.SaveJustifiedCheckpoint(ctx, cp)) + + p := StateProvider{ + BeaconDB: db, + } + + s, err := p.StateRoot(ctx, []byte("justified")) + require.NoError(t, err) + assert.DeepEqual(t, blk.Block.StateRoot, s) + }) + + t.Run("Hex root", func(t *testing.T) { + stateId, err := hexutil.Decode("0x" + strings.Repeat("0", 63) + "1") + require.NoError(t, err) + + p := StateProvider{ + ChainInfoFetcher: &chainMock.ChainService{State: state}, + } + + s, err := p.StateRoot(ctx, stateId) + require.NoError(t, err) + assert.DeepEqual(t, stateId, s) + }) + + t.Run("Hex root not found", func(t *testing.T) { + p := StateProvider{ + ChainInfoFetcher: &chainMock.ChainService{State: state}, + } + stateId, err := hexutil.Decode("0x" + strings.Repeat("f", 64)) + require.NoError(t, err) + _, err = p.StateRoot(ctx, stateId) + require.ErrorContains(t, "state root not found in the last 8192 state roots", err) + }) + + t.Run("Slot", func(t *testing.T) { + db := testDB.SetupDB(t) + genesis := bytesutil.ToBytes32([]byte("genesis")) + require.NoError(t, db.SaveGenesisBlockRoot(ctx, genesis)) + blk := testutil.NewBeaconBlock() + blk.Block.ParentRoot = genesis[:] + blk.Block.Slot = 40 + root, err := blk.Block.HashTreeRoot() + require.NoError(t, err) + require.NoError(t, db.SaveBlock(ctx, interfaces.WrappedPhase0SignedBeaconBlock(blk))) + st, err := testutil.NewBeaconState() + require.NoError(t, err) + require.NoError(t, st.SetSlot(1)) + // a state is required to save checkpoint + require.NoError(t, db.SaveState(ctx, st, root)) + + slot := types.Slot(40) + p := StateProvider{ + GenesisTimeFetcher: &chainMock.ChainService{Slot: &slot}, + BeaconDB: db, + } + + s, err := p.StateRoot(ctx, []byte(strconv.FormatUint(uint64(slot), 10))) + require.NoError(t, err) + assert.DeepEqual(t, blk.Block.StateRoot, s) + }) + + t.Run("Slot too big", func(t *testing.T) { + p := StateProvider{ + GenesisTimeFetcher: &chainMock.ChainService{ + Genesis: time.Now(), + }, + } + _, err := p.StateRoot(ctx, []byte(strconv.FormatUint(1, 10))) + assert.ErrorContains(t, "slot cannot be in the future", err) + }) + + t.Run("Invalid state", func(t *testing.T) { + p := StateProvider{} + _, err := p.StateRoot(ctx, []byte("foo")) + require.ErrorContains(t, "could not parse state ID", err) }) } diff --git a/beacon-chain/rpc/testutil/mock_state_fetcher.go b/beacon-chain/rpc/testutil/mock_state_fetcher.go index 58b36639a9a3..c018dc9ea8e4 100644 --- a/beacon-chain/rpc/testutil/mock_state_fetcher.go +++ b/beacon-chain/rpc/testutil/mock_state_fetcher.go @@ -8,10 +8,16 @@ import ( // MockFetcher is a fake implementation of statefetcher.Fetcher. type MockFetcher struct { - BeaconState iface.BeaconState + BeaconState iface.BeaconState + BeaconStateRoot []byte } // State -- func (m *MockFetcher) State(context.Context, []byte) (iface.BeaconState, error) { return m.BeaconState, nil } + +// State -- +func (m *MockFetcher) StateRoot(context.Context, []byte) ([]byte, error) { + return m.BeaconStateRoot, nil +} diff --git a/beacon-chain/state/stateV0/getters_validator.go b/beacon-chain/state/stateV0/getters_validator.go index 4c885adc8999..36537523d8a4 100644 --- a/beacon-chain/state/stateV0/getters_validator.go +++ b/beacon-chain/state/stateV0/getters_validator.go @@ -19,6 +19,24 @@ import ( "github.com/prysmaticlabs/prysm/shared/params" ) +// IndexOutOfRangeError represents an error scenario where a validator does not exist +// at a given index in the validator's array. +type IndexOutOfRangeError struct { + message string +} + +// NewStateNotFoundError creates a new error instance. +func NewIndexOutOfRangeError(index types.ValidatorIndex) IndexOutOfRangeError { + return IndexOutOfRangeError{ + message: fmt.Sprintf("index %d out of range", index), + } +} + +// Error returns the underlying error message. +func (e *IndexOutOfRangeError) Error() string { + return e.message +} + // Validators participating in consensus on the beacon chain. func (b *BeaconState) Validators() []*ethpb.Validator { if !b.hasInnerState() { @@ -87,7 +105,8 @@ func (b *BeaconState) ValidatorAtIndex(idx types.ValidatorIndex) (*ethpb.Validat return ðpb.Validator{}, nil } if uint64(len(b.state.Validators)) <= uint64(idx) { - return nil, fmt.Errorf("index %d out of range", idx) + e := NewIndexOutOfRangeError(idx) + return nil, &e } b.lock.RLock() @@ -107,7 +126,8 @@ func (b *BeaconState) ValidatorAtIndexReadOnly(idx types.ValidatorIndex) (iface. return ReadOnlyValidator{}, nil } if uint64(len(b.state.Validators)) <= uint64(idx) { - return ReadOnlyValidator{}, fmt.Errorf("index %d out of range", idx) + e := NewIndexOutOfRangeError(idx) + return ReadOnlyValidator{}, &e } b.lock.RLock()