diff --git a/beacon-chain/core/state/transition.go b/beacon-chain/core/state/transition.go index 2d59a29f70ed..b402f758f779 100644 --- a/beacon-chain/core/state/transition.go +++ b/beacon-chain/core/state/transition.go @@ -72,7 +72,7 @@ func ExecuteStateTransition( // Execute per epoch transition. if e.CanProcessEpoch(state) { - state, err = ProcessEpoch(ctx, state, config) + state, err = ProcessEpoch(ctx, state, block, config) } if err != nil { return nil, fmt.Errorf("could not process epoch: %v", err) @@ -195,7 +195,7 @@ func ProcessBlock( // process_crosslink_reward_penalties(state) // update_validator_registry(state) // final_book_keeping(state) -func ProcessEpoch(ctx context.Context, state *pb.BeaconState, config *TransitionConfig) (*pb.BeaconState, error) { +func ProcessEpoch(ctx context.Context, state *pb.BeaconState, block *pb.BeaconBlock, config *TransitionConfig) (*pb.BeaconState, error) { ctx, span := trace.StartSpan(ctx, "beacon-chain.ChainService.state.ProcessEpoch") defer span.End() @@ -400,9 +400,11 @@ func ProcessEpoch(ctx context.Context, state *pb.BeaconState, config *Transition state = e.ProcessPrevSlotShardSeed(state) state = v.ProcessPenaltiesAndExits(state) if e.CanProcessValidatorRegistry(state) { - state, err = v.UpdateRegistry(state) - if err != nil { - return nil, fmt.Errorf("could not update validator registry: %v", err) + if block != nil { + state, err = v.UpdateRegistry(state) + if err != nil { + return nil, fmt.Errorf("could not update validator registry: %v", err) + } } state, err = e.ProcessCurrSlotShardSeed(state) if err != nil { diff --git a/beacon-chain/core/state/transition_test.go b/beacon-chain/core/state/transition_test.go index 70b672f7b589..bd9bc1d72e97 100644 --- a/beacon-chain/core/state/transition_test.go +++ b/beacon-chain/core/state/transition_test.go @@ -457,12 +457,83 @@ func TestProcessEpoch_PassesProcessingConditions(t *testing.T) { params.BeaconConfig().LatestSlashedExitLength), } - _, err := state.ProcessEpoch(context.Background(), newState, state.DefaultConfig()) + _, err := state.ProcessEpoch(context.Background(), newState, &pb.BeaconBlock{}, state.DefaultConfig()) if err != nil { t.Errorf("Expected epoch transition to pass processing conditions: %v", err) } } +func TestProcessEpoch_PreventsRegistryUpdateOnNilBlock(t *testing.T) { + featureconfig.InitFeatureConfig(&featureconfig.FeatureFlagConfig{ + EnableCrosslinks: false, + }) + var validatorRegistry []*pb.Validator + for i := uint64(0); i < 10; i++ { + validatorRegistry = append(validatorRegistry, + &pb.Validator{ + ExitEpoch: params.BeaconConfig().FarFutureEpoch, + }) + } + validatorBalances := make([]uint64, len(validatorRegistry)) + for i := 0; i < len(validatorBalances); i++ { + validatorBalances[i] = params.BeaconConfig().MaxDepositAmount + } + + var attestations []*pb.PendingAttestation + for i := uint64(0); i < params.BeaconConfig().SlotsPerEpoch*2; i++ { + attestations = append(attestations, &pb.PendingAttestation{ + Data: &pb.AttestationData{ + Slot: i + params.BeaconConfig().SlotsPerEpoch + params.BeaconConfig().GenesisSlot, + Shard: 1, + JustifiedEpoch: params.BeaconConfig().GenesisEpoch + 1, + JustifiedBlockRootHash32: []byte{0}, + }, + InclusionSlot: i + params.BeaconConfig().SlotsPerEpoch + 1 + params.BeaconConfig().GenesisSlot, + }) + } + + var blockRoots [][]byte + for i := uint64(0); i < params.BeaconConfig().LatestBlockRootsLength; i++ { + blockRoots = append(blockRoots, []byte{byte(i)}) + } + + var randaoHashes [][]byte + for i := uint64(0); i < params.BeaconConfig().SlotsPerEpoch; i++ { + randaoHashes = append(randaoHashes, []byte{byte(i)}) + } + + crosslinkRecord := make([]*pb.Crosslink, 64) + newState := &pb.BeaconState{ + Slot: params.BeaconConfig().SlotsPerEpoch + params.BeaconConfig().GenesisSlot + 1, + LatestAttestations: attestations, + ValidatorBalances: validatorBalances, + ValidatorRegistry: validatorRegistry, + LatestBlockRootHash32S: blockRoots, + LatestCrosslinks: crosslinkRecord, + LatestRandaoMixes: randaoHashes, + LatestIndexRootHash32S: make([][]byte, + params.BeaconConfig().LatestActiveIndexRootsLength), + LatestSlashedBalances: make([]uint64, + params.BeaconConfig().LatestSlashedExitLength), + ValidatorRegistryUpdateEpoch: params.BeaconConfig().GenesisEpoch, + FinalizedEpoch: params.BeaconConfig().GenesisEpoch + 1, + } + + newState, err := state.ProcessEpoch(context.Background(), newState, nil, state.DefaultConfig()) + if err != nil { + t.Errorf("Expected epoch transition to pass processing conditions: %v", err) + } + if newState.ValidatorRegistryUpdateEpoch != params.BeaconConfig().GenesisEpoch { + t.Errorf( + "Expected registry to not have been updated, received update epoch: %v", + newState.ValidatorRegistryUpdateEpoch-params.BeaconConfig().GenesisEpoch, + ) + } + featureconfig.InitFeatureConfig(&featureconfig.FeatureFlagConfig{ + EnableCrosslinks: true, + }) +} + func TestProcessEpoch_InactiveConditions(t *testing.T) { defaultBalance := params.BeaconConfig().MaxDepositAmount @@ -517,7 +588,7 @@ func TestProcessEpoch_InactiveConditions(t *testing.T) { params.BeaconConfig().LatestSlashedExitLength), } - _, err := state.ProcessEpoch(context.Background(), newState, state.DefaultConfig()) + _, err := state.ProcessEpoch(context.Background(), newState, &pb.BeaconBlock{}, state.DefaultConfig()) if err != nil { t.Errorf("Expected epoch transition to pass processing conditions: %v", err) } @@ -536,7 +607,7 @@ func TestProcessEpoch_CantGetBoundaryAttestation(t *testing.T) { 0, newState.Slot-params.BeaconConfig().GenesisSlot, ) - if _, err := state.ProcessEpoch(context.Background(), newState, state.DefaultConfig()); !strings.Contains(err.Error(), want) { + if _, err := state.ProcessEpoch(context.Background(), newState, &pb.BeaconBlock{}, state.DefaultConfig()); !strings.Contains(err.Error(), want) { t.Errorf("Expected: %s, received: %v", want, err) } } @@ -566,7 +637,7 @@ func TestProcessEpoch_CantGetCurrentValidatorIndices(t *testing.T) { } wanted := fmt.Sprintf("wanted participants bitfield length %d, got: %d", 0, 1) - if _, err := state.ProcessEpoch(context.Background(), newState, state.DefaultConfig()); !strings.Contains(err.Error(), wanted) { + if _, err := state.ProcessEpoch(context.Background(), newState, &pb.BeaconBlock{}, state.DefaultConfig()); !strings.Contains(err.Error(), wanted) { t.Errorf("Expected: %s, received: %v", wanted, err) } }