Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactored transitions methods in Erigon-CL #6714

Merged
merged 1 commit into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 132 additions & 1 deletion cmd/erigon-cl/core/state/accessors.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package state

import (
"crypto/sha256"
"encoding/binary"
"fmt"

libcommon "github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon/cl/clparams"
"github.com/ledgerwatch/erigon/cl/cltypes"
"github.com/ledgerwatch/erigon/cl/fork"
"github.com/ledgerwatch/erigon/cl/utils"
)

// GetActiveValidatorsIndices returns the list of validator indices active for the given epoch.
Expand All @@ -18,9 +23,14 @@ func (b *BeaconState) GetActiveValidatorsIndices(epoch uint64) (indicies []uint6
return
}

// GetEpochAtSlot gives the epoch for a certain slot
func (b *BeaconState) GetEpochAtSlot(slot uint64) uint64 {
return slot / b.beaconConfig.SlotsPerEpoch
}

// Epoch returns current epoch.
func (b *BeaconState) Epoch() uint64 {
return b.slot / b.beaconConfig.SlotsPerEpoch // Return current state epoch
return b.GetEpochAtSlot(b.slot)
}

// PreviousEpoch returns previous epoch.
Expand Down Expand Up @@ -96,3 +106,124 @@ func (b *BeaconState) GetBlockRootAtSlot(slot uint64) (libcommon.Hash, error) {
}
return b.blockRoots[slot%b.beaconConfig.SlotsPerHistoricalRoot], nil
}

func (b *BeaconState) GetDomain(domainType [4]byte, epoch uint64) ([]byte, error) {
if epoch == 0 {
epoch = b.Epoch()
}
var forkVersion [4]byte
if epoch < b.fork.Epoch {
forkVersion = b.fork.PreviousVersion
} else {
forkVersion = b.fork.CurrentVersion
}
return fork.ComputeDomain(domainType[:], forkVersion, b.genesisValidatorsRoot)
}

func (b *BeaconState) ComputeShuffledIndex(ind, ind_count uint64, seed [32]byte) (uint64, error) {
if ind >= ind_count {
return 0, fmt.Errorf("index=%d must be less than the index count=%d", ind, ind_count)
}

for i := uint64(0); i < b.beaconConfig.ShuffleRoundCount; i++ {
// Construct first hash input.
input := append(seed[:], byte(i))
hashedInput := utils.Keccak256(input)

// Read hash value.
hashValue := binary.LittleEndian.Uint64(hashedInput[:8])

// Caclulate pivot and flip.
pivot := hashValue % ind_count
flip := (pivot + ind_count - ind) % ind_count

// No uint64 max function in go standard library.
position := ind
if flip > ind {
position = flip
}

// Construct the second hash input.
positionByteArray := make([]byte, 4)
binary.LittleEndian.PutUint32(positionByteArray, uint32(position>>8))
input2 := append(seed[:], byte(i))
input2 = append(input2, positionByteArray...)

hashedInput2 := utils.Keccak256(input2)
// Read hash value.
byteVal := hashedInput2[(position%256)/8]
bitVal := (byteVal >> (position % 8)) % 2
if bitVal == 1 {
ind = flip
}
}
return ind, nil
}

func (b *BeaconState) ComputeProposerIndex(indices []uint64, seed [32]byte) (uint64, error) {
if len(indices) == 0 {
return 0, fmt.Errorf("must have >0 indices")
}
maxRandomByte := uint64(1<<8 - 1)
i := uint64(0)
total := uint64(len(indices))
buf := make([]byte, 8)
for {
shuffled, err := b.ComputeShuffledIndex(i%total, total, seed)
if err != nil {
return 0, err
}
candidateIndex := indices[shuffled]
if candidateIndex >= uint64(len(b.validators)) {
return 0, fmt.Errorf("candidate index out of range: %d for validator set of length: %d", candidateIndex, len(b.validators))
}
binary.LittleEndian.PutUint64(buf, i/32)
input := append(seed[:], buf...)
randomByte := uint64(utils.Keccak256(input)[i%32])
effectiveBalance := b.validators[candidateIndex].EffectiveBalance
if effectiveBalance*maxRandomByte >= clparams.MainnetBeaconConfig.MaxEffectiveBalance*randomByte {
return candidateIndex, nil
}
i += 1
}
}

func (b *BeaconState) GetRandaoMixes(epoch uint64) [32]byte {
return b.randaoMixes[epoch%b.beaconConfig.EpochsPerHistoricalVector]
}

func (b *BeaconState) GetBeaconProposerIndex() (uint64, error) {
epoch := b.Epoch()

hash := sha256.New()
// Input for the seed hash.
input := b.GetSeed(epoch, clparams.MainnetBeaconConfig.DomainBeaconProposer)
slotByteArray := make([]byte, 8)
binary.LittleEndian.PutUint64(slotByteArray, b.Slot())

// Add slot to the end of the input.
inputWithSlot := append(input, slotByteArray...)

// Calculate the hash.
hash.Write(inputWithSlot)
seed := hash.Sum(nil)

indices := b.GetActiveValidatorsIndices(epoch)

// Write the seed to an array.
seedArray := [32]byte{}
copy(seedArray[:], seed)

return b.ComputeProposerIndex(indices, seedArray)
}

func (b *BeaconState) GetSeed(epoch uint64, domain [4]byte) []byte {
mix := b.GetRandaoMixes(epoch + b.beaconConfig.EpochsPerHistoricalVector - b.beaconConfig.MinSeedLookahead - 1)
epochByteArray := make([]byte, 8)
binary.LittleEndian.PutUint64(epochByteArray, epoch)
input := append(domain[:], epochByteArray...)
input = append(input, mix[:]...)
hash := sha256.New()
hash.Write(input)
return hash.Sum(nil)
}
184 changes: 184 additions & 0 deletions cmd/erigon-cl/core/state/accessors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,33 @@ import (
"testing"

"github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon/cl/clparams"
"github.com/ledgerwatch/erigon/cl/cltypes"
"github.com/ledgerwatch/erigon/cmd/erigon-cl/core/state"
"github.com/stretchr/testify/require"
)

func getTestState(t *testing.T) *state.BeaconState {
numVals := 2048
validators := make([]*cltypes.Validator, numVals)
for i := 0; i < numVals; i++ {
validators[i] = &cltypes.Validator{
ActivationEpoch: 0,
ExitEpoch: 10000,
}
}
b := state.GetEmptyBeaconState()
b.SetValidators(validators)
b.SetSlot(19)
b.SetLatestBlockHeader(&cltypes.BeaconBlockHeader{Slot: 18})
b.SetFork(&cltypes.Fork{
Epoch: 0,
PreviousVersion: [4]byte{0, 1, 2, 3},
CurrentVersion: [4]byte{3, 2, 1, 0},
})
return b
}

func TestActiveValidatorIndices(t *testing.T) {
epoch := uint64(2)
testState := state.GetEmptyBeaconState()
Expand Down Expand Up @@ -48,3 +70,165 @@ func TestGetBlockRoot(t *testing.T) {
require.NoError(t, err)
require.Equal(t, retrieved, root)
}

func TestGetBeaconProposerIndex(t *testing.T) {
state := getTestState(t)
numVals := 2048
validators := make([]*cltypes.Validator, numVals)
for i := 0; i < numVals; i++ {
validators[i] = &cltypes.Validator{
ActivationEpoch: 0,
ExitEpoch: 10000,
}
}
testCases := []struct {
description string
slot uint64
expected uint64
}{
{
description: "slot1",
slot: 1,
expected: 2039,
},
{
description: "slot5",
slot: 5,
expected: 1895,
},
{
description: "slot19",
slot: 19,
expected: 1947,
},
{
description: "slot30",
slot: 30,
expected: 369,
},
{
description: "slot43",
slot: 43,
expected: 464,
},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
state.SetSlot(tc.slot)
got, err := state.GetBeaconProposerIndex()
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if got != tc.expected {
t.Errorf("unexpected result: got %d, want %d", got, tc.expected)
}
})
}
}

func TestComputeShuffledIndex(t *testing.T) {
testCases := []struct {
description string
startInds []uint64
expectedInds []uint64
seed [32]byte
}{
{
description: "success",
startInds: []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
expectedInds: []uint64{0, 9, 8, 4, 6, 7, 3, 1, 2, 5},
seed: [32]byte{1, 128, 12},
},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
for i, val := range tc.startInds {
state := state.New(&clparams.MainnetBeaconConfig)
got, err := state.ComputeShuffledIndex(val, uint64(len(tc.startInds)), tc.seed)
// Non-failure case.
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if got != tc.expectedInds[i] {
t.Errorf("unexpected result: got %d, want %d", got, tc.expectedInds[i])
}
}
})
}
}

func generateBeaconStateWithValidators(n int) *state.BeaconState {
b := state.GetEmptyBeaconState()
for i := 0; i < n; i++ {
b.AddValidator(&cltypes.Validator{EffectiveBalance: clparams.MainnetBeaconConfig.MaxEffectiveBalance})
}
return b
}

func TestComputeProposerIndex(t *testing.T) {
seed := [32]byte{}
copy(seed[:], []byte("seed"))
testCases := []struct {
description string
state *state.BeaconState
indices []uint64
seed [32]byte
expected uint64
wantErr bool
}{
{
description: "success",
state: generateBeaconStateWithValidators(5),
indices: []uint64{0, 1, 2, 3, 4},
seed: seed,
expected: 2,
},
{
description: "single_active_index",
state: generateBeaconStateWithValidators(5),
indices: []uint64{3},
seed: seed,
expected: 3,
},
{
description: "second_half_active",
state: generateBeaconStateWithValidators(10),
indices: []uint64{5, 6, 7, 8, 9},
seed: seed,
expected: 7,
},
{
description: "zero_active_indices",
indices: []uint64{},
seed: seed,
wantErr: true,
},
{
description: "active_index_out_of_range",
indices: []uint64{100},
state: generateBeaconStateWithValidators(1),
seed: seed,
wantErr: true,
},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
got, err := tc.state.ComputeProposerIndex(tc.indices, tc.seed)
if tc.wantErr {
if err == nil {
t.Errorf("unexpected success, wanted error")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if got != tc.expected {
t.Errorf("unexpected result: got %d, want %d", got, tc.expected)
}
})
}
}
4 changes: 4 additions & 0 deletions cmd/erigon-cl/core/state/getters.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ func (b *BeaconState) Balances() []uint64 {
return b.balances
}

func (b *BeaconState) ValidatorBalance(index int) uint64 {
return b.balances[index]
}

func (b *BeaconState) RandaoMixes() [randoMixesLength]libcommon.Hash {
return b.randaoMixes
}
Expand Down
Loading