Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Process Crosslink From 0.6 #2460

Merged
merged 17 commits into from
May 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 74 additions & 22 deletions beacon-chain/core/epoch/epoch_processing.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"fmt"
"sort"
"errors"

"github.com/gogo/protobuf/proto"
"github.com/prysmaticlabs/prysm/beacon-chain/core/helpers"
Expand Down Expand Up @@ -48,7 +49,7 @@ func CanProcessEpoch(state *pb.BeaconState) bool {
// epoch processing. This is where a beacon node can justify and finalize a new epoch.
//
// Spec pseudocode definition:
// def process_justification_and_finalization(state: BeaconState) -> None:
// def process_justification_and_finalization(state: BeaconState) -> None:
// if get_current_epoch(state) <= GENESIS_EPOCH + 1:
// return
//
Expand Down Expand Up @@ -157,9 +158,52 @@ func ProcessJustificationFinalization(state *pb.BeaconState, prevAttestedBal uin
return state, nil
}

// ProcessCrosslink processes crosslink and finds the crosslink
// with enough state to make it canonical in state.
//
// Spec pseudocode definition:
// def process_crosslinks(state: BeaconState) -> None:
// state.previous_crosslinks = [c for c in state.current_crosslinks]
// for epoch in (get_previous_epoch(state), get_current_epoch(state)):
// for offset in range(get_epoch_committee_count(state, epoch)):
// shard = (get_epoch_start_shard(state, epoch) + offset) % SHARD_COUNT
// crosslink_committee = get_crosslink_committee(state, epoch, shard)
// winning_crosslink, attesting_indices = get_winning_crosslink_and_attesting_indices(state, epoch, shard)
// if 3 * get_total_balance(state, attesting_indices) >= 2 * get_total_balance(state, crosslink_committee):
// state.current_crosslinks[shard] = winning_crosslink
func ProcessCrosslink(state *pb.BeaconState) (*pb.BeaconState, error) {
state.PreviousCrosslinks = state.CurrentCrosslinks
epochs := []uint64{helpers.PrevEpoch(state), helpers.CurrentEpoch(state)}
for _, e := range epochs {
offset := helpers.EpochCommitteeCount(state, e)
for i := uint64(0); i < offset; i++ {
shard, err := helpers.EpochStartShard(state, e)
if err != nil {
return nil, err
}
committee, err := helpers.CrosslinkCommitteeAtEpoch(state, e, shard)
if err != nil {
return nil, err
}
crosslink, indices, err := WinningCrosslink(state, shard, e)
if err != nil {
return nil, err
}
attestedBalance := helpers.TotalBalance(state, indices)
totalBalance := helpers.TotalBalance(state, committee)
// In order for a crosslink to get included in state, the attesting balance needs to
// be greater than 2/3 of the total balance.
if 3*attestedBalance >= 2*totalBalance {
state.CurrentCrosslinks[shard] = crosslink
}
}
}
return state, nil
}

// ProcessSlashings processes the slashed validators during epoch processing,
//
// def process_slashings(state: BeaconState) -> None:
// def process_slashings(state: BeaconState) -> None:
// current_epoch = get_current_epoch(state)
// active_validator_indices = get_active_validator_indices(state, current_epoch)
// total_balance = get_total_balance(state, active_validator_indices)
Expand Down Expand Up @@ -209,7 +253,7 @@ func ProcessSlashings(state *pb.BeaconState) *pb.BeaconState {
// ProcessFinalUpdates processes the final updates during epoch processing.
//
// Spec pseudocode definition:
// def process_final_updates(state: BeaconState) -> None:
// def process_final_updates(state: BeaconState) -> None:
// current_epoch = get_current_epoch(state)
// next_epoch = current_epoch + 1
// # Reset eth1 data votes
Expand Down Expand Up @@ -314,7 +358,7 @@ func ProcessFinalUpdates(state *pb.BeaconState) (*pb.BeaconState, error) {
// `WinningCrosslink` and `CrosslinkAttestingIndices` for clarity and efficiency.
//
// Spec pseudocode definition:
// def get_winning_crosslink_and_attesting_indices(state: BeaconState, shard: Shard, epoch: Epoch) -> Tuple[Crosslink, List[ValidatorIndex]]:
// def get_winning_crosslink_and_attesting_indices(state: BeaconState, shard: Shard, epoch: Epoch) -> Tuple[Crosslink, List[ValidatorIndex]]:
// shard_attestations = [a for a in get_matching_source_attestations(state, epoch) if a.data.shard == shard]
// shard_crosslinks = [get_crosslink_from_attestation_data(state, a.data) for a in shard_attestations]
// candidate_crosslinks = [
Expand All @@ -332,11 +376,11 @@ func ProcessFinalUpdates(state *pb.BeaconState) (*pb.BeaconState, error) {
// ))
//
// return winning_crosslink, get_unslashed_attesting_indices(state, get_attestations_for(winning_crosslink))
func WinningCrosslink(state *pb.BeaconState, shard uint64, epoch uint64) (*pb.Crosslink, error) {
func WinningCrosslink(state *pb.BeaconState, shard uint64, epoch uint64) (*pb.Crosslink, []uint64, error) {
var shardAtts []*pb.PendingAttestation
matchedAtts, err := MatchAttestations(state, epoch)
if err != nil {
return nil, fmt.Errorf("could not get matching attestations: %v", err)
return nil, nil, fmt.Errorf("could not get matching attestations: %v", err)
}

// Filter out source attestations by shard.
Expand All @@ -347,7 +391,7 @@ func WinningCrosslink(state *pb.BeaconState, shard uint64, epoch uint64) (*pb.Cr
}

// Convert shard attestations to shard crosslinks.
shardCrosslinks := make([]*pb.Crosslink, len(matchedAtts.source))
shardCrosslinks := make([]*pb.Crosslink, len(shardAtts))
for i := 0; i < len(shardCrosslinks); i++ {
shardCrosslinks[i] = CrosslinkFromAttsData(state, shardAtts[i].Data)
}
Expand All @@ -358,7 +402,7 @@ func WinningCrosslink(state *pb.BeaconState, shard uint64, epoch uint64) (*pb.Cr
cFromState := state.CurrentCrosslinks[shard]
h, err := hashutil.HashProto(cFromState)
if err != nil {
return nil, fmt.Errorf("could not hash crosslink from state: %v", err)
return nil, nil, fmt.Errorf("could not hash crosslink from state: %v", err)
}
if proto.Equal(cFromState, c) || bytes.Equal(h[:], c.PreviousCrosslinkRootHash32) {
candidateCrosslinks = append(candidateCrosslinks, c)
Expand All @@ -370,37 +414,45 @@ func WinningCrosslink(state *pb.BeaconState, shard uint64, epoch uint64) (*pb.Cr
Epoch: params.BeaconConfig().GenesisEpoch,
CrosslinkDataRootHash32: params.BeaconConfig().ZeroHash[:],
PreviousCrosslinkRootHash32: params.BeaconConfig().ZeroHash[:],
}, nil
}, nil, nil
}

var crosslinkAtts []*pb.PendingAttestation
var winnerBalance uint64
var winnerCrosslink *pb.Crosslink
// Out of the existing shard crosslinks, pick the one that has the
// most balance staked.
crosslinkAtts = attsForCrosslink(state, candidateCrosslinks[0], shardAtts)
winnerBalance, err = AttestingBalance(state, crosslinkAtts)
winnerCrosslink = candidateCrosslinks[0]
if err != nil {
return nil, nil, err
}

winnerCrosslink = candidateCrosslinks[0]
for _, c := range candidateCrosslinks {
crosslinkAtts := crosslinkAtts[:0]
crosslinkAtts = attsForCrosslink(state, c, shardAtts)
attestingBalance, err := AttestingBalance(state, crosslinkAtts)
if err != nil {
return nil, fmt.Errorf("could not get crosslink's attesting balance: %v", err)
return nil, nil, fmt.Errorf("could not get crosslink's attesting balance: %v", err)
}
if attestingBalance > winnerBalance {
winnerCrosslink = c
}
}
return winnerCrosslink, nil

crosslinkIndices, err := UnslashedAttestingIndices(state, attsForCrosslink(state, winnerCrosslink, shardAtts))
if err != nil {
return nil, nil, errors.New("could not get crosslink indices")
}

return winnerCrosslink, crosslinkIndices, nil
}

// UnslashedAttestingIndices returns all the attesting indices from a list of attestations,
// it sorts the indices and filters out the slashed ones.
//
// Spec pseudocode definition:
// def get_unslashed_attesting_indices(state: BeaconState, attestations: List[PendingAttestation]) -> List[ValidatorIndex]:
// def get_unslashed_attesting_indices(state: BeaconState, attestations: List[PendingAttestation]) -> List[ValidatorIndex]:
// output = set()
// for a in attestations:
// output = output.union(get_attesting_indices(state, a.data, a.aggregation_bitfield))
Expand Down Expand Up @@ -428,7 +480,7 @@ func UnslashedAttestingIndices(state *pb.BeaconState, atts []*pb.PendingAttestat
// AttestingBalance returns the total balance from all the attesting indices.
//
// Spec pseudocode definition:
// def get_attesting_balance(state: BeaconState, attestations: List[PendingAttestation]) -> Gwei:
// def get_attesting_balance(state: BeaconState, attestations: List[PendingAttestation]) -> Gwei:
// return get_total_balance(state, get_unslashed_attesting_indices(state, attestations))
func AttestingBalance(state *pb.BeaconState, atts []*pb.PendingAttestation) (uint64, error) {
indices, err := UnslashedAttestingIndices(state, atts)
Expand All @@ -441,7 +493,7 @@ func AttestingBalance(state *pb.BeaconState, atts []*pb.PendingAttestation) (uin
// EarlistAttestation returns attestation with the earliest inclusion slot.
//
// Spec pseudocode definition:
// def get_earliest_attestation(state: BeaconState, attestations: List[PendingAttestation], index: ValidatorIndex) -> PendingAttestation:
// def get_earliest_attestation(state: BeaconState, attestations: List[PendingAttestation], index: ValidatorIndex) -> PendingAttestation:
// return min([
// a for a in attestations if index in get_attesting_indices(state, a.data, a.aggregation_bitfield)
// ], key=lambda a: a.inclusion_slot)
Expand Down Expand Up @@ -470,17 +522,17 @@ func EarlistAttestation(state *pb.BeaconState, atts []*pb.PendingAttestation, in
// We combined the individual helpers from spec for efficiency and to achieve O(N) run time.
//
// Spec pseudocode definition:
// def get_matching_source_attestations(state: BeaconState, epoch: Epoch) -> List[PendingAttestation]:
// def get_matching_source_attestations(state: BeaconState, epoch: Epoch) -> List[PendingAttestation]:
// assert epoch in (get_current_epoch(state), get_previous_epoch(state))
// return state.current_epoch_attestations if epoch == get_current_epoch(state) else state.previous_epoch_attestations
//
// def get_matching_target_attestations(state: BeaconState, epoch: Epoch) -> List[PendingAttestation]:
// def get_matching_target_attestations(state: BeaconState, epoch: Epoch) -> List[PendingAttestation]:
// return [
// a for a in get_matching_source_attestations(state, epoch)
// if a.data.target_root == get_block_root(state, epoch)
// ]
//
// def get_matching_head_attestations(state: BeaconState, epoch: Epoch) -> List[PendingAttestation]:
// def get_matching_head_attestations(state: BeaconState, epoch: Epoch) -> List[PendingAttestation]:
// return [
// a for a in get_matching_source_attestations(state, epoch)
// if a.data.beacon_block_root == get_block_root_at_slot(state, a.data.slot)
Expand Down Expand Up @@ -539,7 +591,7 @@ func MatchAttestations(state *pb.BeaconState, epoch uint64) (*MatchedAttestation
// CrosslinkFromAttsData returns a constructed crosslink from attestation data.
//
// Spec pseudocode definition:
// def get_crosslink_from_attestation_data(state: BeaconState, data: AttestationData) -> Crosslink:
// def get_crosslink_from_attestation_data(state: BeaconState, data: AttestationData) -> Crosslink:
// return Crosslink(
// epoch=min(slot_to_epoch(data.slot), state.current_crosslinks[data.shard].epoch + MAX_CROSSLINK_EPOCHS),
// previous_crosslink_root=data.previous_crosslink_root,
Expand Down Expand Up @@ -567,7 +619,7 @@ func CrosslinkAttestingIndices(state *pb.BeaconState, crosslink *pb.Crosslink, a
// individual validator's base reward quotient.
//
// Spec pseudocode definition:
// def get_base_reward(state: BeaconState, index: ValidatorIndex) -> Gwei:
// def get_base_reward(state: BeaconState, index: ValidatorIndex) -> Gwei:
// adjusted_quotient = integer_squareroot(get_total_active_balance(state)) // BASE_REWARD_QUOTIENT
// if adjusted_quotient == 0:
// return 0
Expand Down Expand Up @@ -596,7 +648,7 @@ func attsForCrosslink(state *pb.BeaconState, crosslink *pb.Crosslink, atts []*pb
// totalActiveBalance returns the combined balances of all the active validators.
//
// Spec pseudocode definition:
// def get_total_active_balance(state: BeaconState) -> Gwei:
// def get_total_active_balance(state: BeaconState) -> Gwei:
// return get_total_balance(state, get_active_validator_indices(state, get_current_epoch(state)))
func totalActiveBalance(state *pb.BeaconState) uint64 {
return helpers.TotalBalance(state, helpers.ActiveValidatorIndices(state, helpers.CurrentEpoch(state)))
Expand Down
116 changes: 112 additions & 4 deletions beacon-chain/core/epoch/epoch_processing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ func TestCrosslinkAttestingIndices_CanGetIndices(t *testing.T) {
func TestWinningCrosslink_CantGetMatchingAtts(t *testing.T) {
wanted := fmt.Sprintf("could not get matching attestations: input epoch: %d != current epoch: %d or previous epoch: %d",
100, params.BeaconConfig().GenesisEpoch, params.BeaconConfig().GenesisEpoch)
_, err := WinningCrosslink(&pb.BeaconState{Slot: params.BeaconConfig().GenesisSlot}, 0, 100)
_, _, err := WinningCrosslink(&pb.BeaconState{Slot: params.BeaconConfig().GenesisSlot}, 0, 100)
if err.Error() != wanted {
t.Fatal(err)
}
Expand All @@ -548,10 +548,13 @@ func TestWinningCrosslink_ReturnGensisCrosslink(t *testing.T) {
PreviousCrosslinkRootHash32: params.BeaconConfig().ZeroHash[:],
}

crosslink, err := WinningCrosslink(state, 0, ge)
crosslink, indices, err := WinningCrosslink(state, 0, ge)
if err != nil {
t.Fatal(err)
}
if len(indices) != 0 {
t.Errorf("gensis crosslink indices is not 0, got: %d", len(indices))
}
if !reflect.DeepEqual(crosslink, gCrosslink) {
t.Errorf("Did not get genesis crosslink, got: %v", crosslink)
}
Expand Down Expand Up @@ -601,17 +604,122 @@ func TestWinningCrosslink_CanGetWinningRoot(t *testing.T) {
LatestActiveIndexRoots: make([][]byte, params.BeaconConfig().LatestActiveIndexRootsLength),
}

winner, err := WinningCrosslink(state, 3, ge)
winner, indices, err := WinningCrosslink(state, 0, ge)
if err != nil {
t.Fatal(err)
}

if len(indices) != 0 {
t.Errorf("gensis crosslink indices is not 0, got: %d", len(indices))
}
want := &pb.Crosslink{Epoch: ge, CrosslinkDataRootHash32: []byte{'B'}}
if !reflect.DeepEqual(winner, want) {
t.Errorf("Did not get genesis crosslink, got: %v", winner)
}
}

func TestProcessCrosslink_NoUpdate(t *testing.T) {
validators := make([]*pb.Validator, params.BeaconConfig().DepositsForChainStart)
balances := make([]uint64, params.BeaconConfig().DepositsForChainStart)
for i := 0; i < len(validators); i++ {
validators[i] = &pb.Validator{
ExitEpoch: params.BeaconConfig().FarFutureEpoch,
EffectiveBalance: params.BeaconConfig().MaxDepositAmount,
}
balances[i] = params.BeaconConfig().MaxDepositAmount
}
blockRoots := make([][]byte, 128)
for i := 0; i < len(blockRoots); i++ {
blockRoots[i] = []byte{byte(i + 1)}
}
oldCrosslink := &pb.Crosslink{
Epoch: params.BeaconConfig().GenesisEpoch,
CrosslinkDataRootHash32: []byte{'A'},
}
var crosslinks []*pb.Crosslink
for i := uint64(0); i < params.BeaconConfig().ShardCount; i++ {
crosslinks = append(crosslinks, &pb.Crosslink{
Epoch: params.BeaconConfig().GenesisEpoch,
CrosslinkDataRootHash32: []byte{'A'},
})
}
state := &pb.BeaconState{
Slot: params.BeaconConfig().GenesisSlot + params.BeaconConfig().SlotsPerEpoch + 1,
ValidatorRegistry: validators,
Balances: balances,
LatestBlockRoots: blockRoots,
LatestRandaoMixes: make([][]byte, params.BeaconConfig().LatestRandaoMixesLength),
LatestActiveIndexRoots: make([][]byte, params.BeaconConfig().LatestActiveIndexRootsLength),
CurrentCrosslinks: crosslinks,
}
newState, err := ProcessCrosslink(state)
if err != nil {
t.Fatal(err)
}

// Since there has been no attestation, crosslink stayed the same.
if !reflect.DeepEqual(oldCrosslink, newState.CurrentCrosslinks[0]) {
t.Errorf("Did not get correct crosslink back")
}
}

func TestProcessCrosslink_SuccessfulUpdate(t *testing.T) {
e := params.BeaconConfig().SlotsPerEpoch
gs := params.BeaconConfig().GenesisSlot
ge := params.BeaconConfig().GenesisEpoch

validators := make([]*pb.Validator, params.BeaconConfig().DepositsForChainStart/8)
balances := make([]uint64, params.BeaconConfig().DepositsForChainStart/8)
for i := 0; i < len(validators); i++ {
validators[i] = &pb.Validator{
ExitEpoch: params.BeaconConfig().FarFutureEpoch,
EffectiveBalance: params.BeaconConfig().MaxDepositAmount,
}
balances[i] = params.BeaconConfig().MaxDepositAmount
}
blockRoots := make([][]byte, 128)
for i := 0; i < len(blockRoots); i++ {
blockRoots[i] = []byte{byte(i + 1)}
}

crosslinks := make([]*pb.Crosslink, params.BeaconConfig().ShardCount)
for i := uint64(0); i < params.BeaconConfig().ShardCount; i++ {
crosslinks[i] = &pb.Crosslink{
Epoch: ge,
CrosslinkDataRootHash32: []byte{'B'},
}
}
var atts []*pb.PendingAttestation
for s := uint64(0); s < params.BeaconConfig().ShardCount; s++ {
atts = append(atts, &pb.PendingAttestation{
Data: &pb.AttestationData{
Slot: gs + 1 + (s % e),
Shard: s,
CrosslinkDataRoot: []byte{'B'},
TargetEpoch: params.BeaconConfig().GenesisEpoch,
},
AggregationBitfield: []byte{0xC0, 0xC0, 0xC0, 0xC0},
})
}
state := &pb.BeaconState{
Slot: gs + e + 2,
ValidatorRegistry: validators,
PreviousEpochAttestations: atts,
Balances: balances,
LatestBlockRoots: blockRoots,
CurrentCrosslinks: crosslinks,
LatestRandaoMixes: make([][]byte, params.BeaconConfig().LatestRandaoMixesLength),
LatestActiveIndexRoots: make([][]byte, params.BeaconConfig().LatestActiveIndexRootsLength),
}
newState, err := ProcessCrosslink(state)
if err != nil {
t.Fatal(err)
}

if !reflect.DeepEqual(crosslinks[0], newState.CurrentCrosslinks[0]) {
t.Errorf("Crosslink is not the same")
}
}

func TestBaseReward_AccurateRewards(t *testing.T) {
tests := []struct {
a uint64
Expand Down
Loading