diff --git a/protocol/syncer.go b/protocol/syncer.go index 258baebbd9..214655e587 100644 --- a/protocol/syncer.go +++ b/protocol/syncer.go @@ -295,15 +295,20 @@ func (s *Syncer) updatePeerStatus(peerID peer.ID, status *Status) { // Broadcast broadcasts a block to all peers func (s *Syncer) Broadcast(b *types.Block) { - // diff is number in ibft - diff := new(big.Int).SetUint64(b.Header.Difficulty) + // Get the chain difficulty associated with block + td, ok := s.blockchain.GetTD(b.Hash()) + if !ok { + // not supposed to happen + s.logger.Error("total difficulty not found", "block number", b.Number()) + return + } // broadcast the new block to all the peers req := &proto.NotifyReq{ Status: &proto.V1Status{ Hash: b.Hash().String(), Number: b.Number(), - Difficulty: diff.String(), + Difficulty: td.String(), }, Raw: &any.Any{ Value: b.MarshalRLP(), @@ -385,7 +390,6 @@ func (s *Syncer) BestPeer() *syncPeer { } curDiff := s.blockchain.CurrentTD() - if bestTd.Cmp(curDiff) <= 0 { return nil } diff --git a/protocol/syncer_test.go b/protocol/syncer_test.go index 5bca90f760..7f589f89f1 100644 --- a/protocol/syncer_test.go +++ b/protocol/syncer_test.go @@ -93,25 +93,29 @@ func TestDeletePeer(t *testing.T) { func TestBroadcast(t *testing.T) { tests := []struct { - name string - chain blockchainShim - peerChain blockchainShim - numNewBlocks int + name string + syncerHeaders []*types.Header + peerHeaders []*types.Header + numNewBlocks int }{ { - name: "syncer should receive new block in peer", - chain: NewRandomChain(t, 5), - peerChain: NewRandomChain(t, 10), - numNewBlocks: 5, + name: "syncer should receive new block in peer", + syncerHeaders: blockchain.NewTestHeaderChainWithSeed(nil, 5, 0), + peerHeaders: blockchain.NewTestHeaderChainWithSeed(nil, 10, 0), + numNewBlocks: 5, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - syncer, peerSyncers := SetupSyncerNetwork(t, tt.chain, []blockchainShim{tt.peerChain}) + chain, peerChain := NewMockBlockchain(tt.syncerHeaders), NewMockBlockchain(tt.peerHeaders) + syncer, peerSyncers := SetupSyncerNetwork(t, chain, []blockchainShim{peerChain}) peerSyncer := peerSyncers[0] newBlocks := GenerateNewBlocks(t, peerSyncer.blockchain, tt.numNewBlocks) + + assert.NoError(t, peerSyncer.blockchain.WriteBlocks(newBlocks)) + for _, newBlock := range newBlocks { peerSyncer.Broadcast(newBlock) } @@ -271,6 +275,9 @@ func TestWatchSyncWithPeer(t *testing.T) { peerSyncer := peerSyncers[0] newBlocks := GenerateNewBlocks(t, peerChain, tt.numNewBlocks) + + assert.NoError(t, peerSyncer.blockchain.WriteBlocks(newBlocks)) + for _, b := range newBlocks { peerSyncer.Broadcast(b) } diff --git a/protocol/testing.go b/protocol/testing.go index 7d79442513..b08df1e75d 100644 --- a/protocol/testing.go +++ b/protocol/testing.go @@ -115,7 +115,7 @@ func SetupSyncerNetwork(t *testing.T, chain blockchainShim, peerChains []blockch network.MultiJoin(t, syncer.server, peerSyncers[idx].server) } WaitUntilPeerConnected(t, syncer, len(peerChains), 10*time.Second) - return + return syncer, peerSyncers } // GenerateNewBlocks returns new blocks from latest block of given chain @@ -166,10 +166,14 @@ func GetCurrentStatus(b blockchainShim) *Status { // HeaderToStatus converts given header to Status func HeaderToStatus(h *types.Header) *Status { + var td uint64 = 0 + for i := uint64(1); i <= h.Difficulty; i++ { + td = td + i + } return &Status{ Hash: h.Hash, Number: h.Number, - Difficulty: big.NewInt(0).SetUint64(h.Difficulty), + Difficulty: big.NewInt(0).SetUint64(td), } } @@ -207,13 +211,18 @@ func (b *mockBlockchain) CurrentTD() *big.Int { if current == nil { return nil } - return new(big.Int).SetUint64(current.Difficulty) + + td, _ := b.GetTD(current.Hash) + return td } func (b *mockBlockchain) GetTD(hash types.Hash) (*big.Int, bool) { + var td uint64 = 0 for _, b := range b.blocks { + td = td + b.Header.Difficulty + if b.Header.Hash == hash { - return new(big.Int).SetUint64(b.Header.Difficulty), true + return big.NewInt(0).SetUint64(td), true } } return nil, false @@ -264,8 +273,9 @@ func NewMockSubscription() *mockSubscription { func (s *mockSubscription) AppendBlocks(blocks []*types.Block) { for _, b := range blocks { + status := HeaderToStatus(b.Header) s.eventCh <- &blockchain.Event{ - Difficulty: new(big.Int).SetUint64(b.Header.Difficulty), + Difficulty: status.Difficulty, NewChain: []*types.Header{b.Header}, } }