diff --git a/beacon-chain/rpc/validator/server.go b/beacon-chain/rpc/validator/server.go index 0984c22015a6..f59ab97ef64a 100644 --- a/beacon-chain/rpc/validator/server.go +++ b/beacon-chain/rpc/validator/server.go @@ -189,6 +189,19 @@ func (vs *Server) WaitForChainStart(req *ptypes.Empty, stream ethpb.BeaconNodeVa } return stream.Send(res) } + // Handle race condition in the event the blockchain + // service isn't initialized in time and the saved head state is nil. + if event.Type == statefeed.Initialized { + data, ok := event.Data.(*statefeed.InitializedData) + if !ok { + return errors.New("event data is not type *statefeed.InitializedData") + } + res := ðpb.ChainStartResponse{ + Started: true, + GenesisTime: uint64(data.StartTime.Unix()), + } + return stream.Send(res) + } case <-stateSub.Err(): return status.Error(codes.Aborted, "Subscriber closed, exiting goroutine") case <-vs.Ctx.Done(): diff --git a/beacon-chain/rpc/validator/server_test.go b/beacon-chain/rpc/validator/server_test.go index 83af2ce8e636..4cd250eebcdc 100644 --- a/beacon-chain/rpc/validator/server_test.go +++ b/beacon-chain/rpc/validator/server_test.go @@ -2,6 +2,7 @@ package validator import ( "context" + "sync" "testing" "time" @@ -337,6 +338,44 @@ func TestWaitForChainStart_AlreadyStarted(t *testing.T) { assert.NoError(t, Server.WaitForChainStart(&ptypes.Empty{}, mockStream), "Could not call RPC method") } +func TestWaitForChainStart_HeadStateDoesNotExist(t *testing.T) { + db, _ := dbutil.SetupDB(t) + genesisValidatorRoot := [32]byte{0x01, 0x02} + + // Set head state to nil + chainService := &mockChain.ChainService{State: nil} + notifier := chainService.StateNotifier() + Server := &Server{ + Ctx: context.Background(), + ChainStartFetcher: &mockPOW.POWChain{ + ChainFeed: new(event.Feed), + }, + BeaconDB: db, + StateNotifier: chainService.StateNotifier(), + HeadFetcher: chainService, + } + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockStream := mock.NewMockBeaconNodeValidator_WaitForChainStartServer(ctrl) + + wg := new(sync.WaitGroup) + wg.Add(1) + go func() { + assert.NoError(t, Server.WaitForChainStart(&ptypes.Empty{}, mockStream), "Could not call RPC method") + wg.Done() + }() + // Simulate a late state initialization event, so that + // method is able to handle race condition here. + notifier.StateFeed().Send(&feed.Event{ + Type: statefeed.Initialized, + Data: &statefeed.InitializedData{ + StartTime: time.Unix(0, 0), + GenesisValidatorsRoot: genesisValidatorRoot[:], + }, + }) + testutil.WaitTimeout(wg, time.Second) +} + func TestWaitForChainStart_NotStartedThenLogFired(t *testing.T) { db, _ := dbutil.SetupDB(t)