diff --git a/dot/sync/configuration.go b/dot/sync/configuration.go index e144a87cbc..646c77ce87 100644 --- a/dot/sync/configuration.go +++ b/dot/sync/configuration.go @@ -17,7 +17,7 @@ func WithStrategies(currentStrategy, defaultStrategy Strategy) ServiceConfig { func WithNetwork(net Network) ServiceConfig { return func(svc *SyncService) { svc.network = net - svc.workerPool = newSyncWorkerPool(net) + //svc.workerPool = newSyncWorkerPool(net) } } diff --git a/dot/sync/fullsync.go b/dot/sync/fullsync.go index 79db8b4a29..789743a5d9 100644 --- a/dot/sync/fullsync.go +++ b/dot/sync/fullsync.go @@ -19,7 +19,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" ) -const defaultNumOfTasks = 3 +const defaultNumOfTasks = 10 var _ Strategy = (*FullSyncStrategy)(nil) @@ -86,7 +86,7 @@ func NewFullSyncStrategy(cfg *FullSyncConfig) *FullSyncStrategy { } } -func (f *FullSyncStrategy) NextActions() ([]*SyncTask, error) { +func (f *FullSyncStrategy) NextActions() ([]Task, error) { f.startedAt = time.Now() f.syncedBlocks = 0 @@ -129,12 +129,11 @@ func (f *FullSyncStrategy) NextActions() ([]*SyncTask, error) { return f.createTasks(reqsFromQueue), nil } -func (f *FullSyncStrategy) createTasks(requests []*messages.BlockRequestMessage) []*SyncTask { - tasks := make([]*SyncTask, 0, len(requests)) +func (f *FullSyncStrategy) createTasks(requests []*messages.BlockRequestMessage) []Task { + tasks := make([]Task, 0, len(requests)) for _, req := range requests { - tasks = append(tasks, &SyncTask{ + tasks = append(tasks, &syncTask{ request: req, - response: &messages.BlockResponseMessage{}, requestMaker: f.reqMaker, }) } @@ -146,126 +145,139 @@ func (f *FullSyncStrategy) createTasks(requests []*messages.BlockRequestMessage) // or complete an incomplete block or is part of a disjoint block set which will // as a result it returns the if the strategy is finished, the peer reputations to change, // peers to block/ban, or an error. FullSyncStrategy is intended to run as long as the node lives. -func (f *FullSyncStrategy) Process(results []*SyncTaskResult) ( +func (f *FullSyncStrategy) Process(results <-chan TaskResult) ( isFinished bool, reputations []Change, bans []peer.ID, err error) { - repChanges, peersToIgnore, validResp := validateResults(results, f.badBlocks) - logger.Debugf("evaluating %d task results, %d valid responses", len(results), len(validResp)) - var highestFinalized *types.Header - highestFinalized, err = f.blockState.GetHighestFinalisedHeader() + highestFinalized, err := f.blockState.GetHighestFinalisedHeader() if err != nil { return false, nil, nil, fmt.Errorf("getting highest finalized header") } - readyBlocks := make([][]*types.BlockData, 0, len(validResp)) - for _, reqRespData := range validResp { - // if Gossamer requested the header, then the response data should contains + // This is safe as long as we are the only goroutine reading from the channel. + for len(results) > 0 { + readyBlocks := make([][]*types.BlockData, 0) + result := <-results + repChange, ignorePeer, validResp := validateResult(result, f.badBlocks) + + if repChange != nil { + reputations = append(reputations, *repChange) + } + + if ignorePeer { + bans = append(bans, result.Who) + } + + if validResp == nil || len(validResp.responseData) == 0 { + continue + } + + // if Gossamer requested the header, then the response data should contain // the full blocks to be imported. 
If Gossamer didn't request the header, // then the response should only contain the missing parts that will complete // the unreadyBlocks and then with the blocks completed we should be able to import them - if reqRespData.req.RequestField(messages.RequestedDataHeader) { - updatedFragment, ok := f.unreadyBlocks.updateDisjointFragments(reqRespData.responseData) + if validResp.req.RequestField(messages.RequestedDataHeader) { + updatedFragment, ok := f.unreadyBlocks.updateDisjointFragments(validResp.responseData) if ok { validBlocks := validBlocksUnderFragment(highestFinalized.Number, updatedFragment) if len(validBlocks) > 0 { readyBlocks = append(readyBlocks, validBlocks) } } else { - readyBlocks = append(readyBlocks, reqRespData.responseData) + readyBlocks = append(readyBlocks, validResp.responseData) } - continue - } - - completedBlocks := f.unreadyBlocks.updateIncompleteBlocks(reqRespData.responseData) - readyBlocks = append(readyBlocks, completedBlocks) - } - - // disjoint fragments are pieces of the chain that could not be imported right now - // because is blocks too far ahead or blocks that belongs to forks - sortFragmentsOfChain(readyBlocks) - orderedFragments := mergeFragmentsOfChain(readyBlocks) - - nextBlocksToImport := make([]*types.BlockData, 0) - disjointFragments := make([][]*types.BlockData, 0) - - for _, fragment := range orderedFragments { - ok, err := f.blockState.HasHeader(fragment[0].Header.ParentHash) - if err != nil && !errors.Is(err, database.ErrNotFound) { - return false, nil, nil, fmt.Errorf("checking block parent header: %w", err) + completedBlocks := f.unreadyBlocks.updateIncompleteBlocks(validResp.responseData) + if len(completedBlocks) > 0 { + readyBlocks = append(readyBlocks, completedBlocks) + } } - if ok { - nextBlocksToImport = append(nextBlocksToImport, fragment...) - continue - } + // disjoint fragments are pieces of the chain that could not be imported right now + // because is blocks too far ahead or blocks that belongs to forks + sortFragmentsOfChain(readyBlocks) + orderedFragments := mergeFragmentsOfChain(readyBlocks) - disjointFragments = append(disjointFragments, fragment) - } + nextBlocksToImport := make([]*types.BlockData, 0) + disjointFragments := make([][]*types.BlockData, 0) - // this loop goal is to import ready blocks as well as update the highestFinalized header - for len(nextBlocksToImport) > 0 || len(disjointFragments) > 0 { - for _, blockToImport := range nextBlocksToImport { - imported, err := f.blockImporter.importBlock(blockToImport, networkInitialSync) - if err != nil { - return false, nil, nil, fmt.Errorf("while handling ready block: %w", err) + for _, fragment := range orderedFragments { + ok, err := f.blockState.HasHeader(fragment[0].Header.ParentHash) + if err != nil && !errors.Is(err, database.ErrNotFound) { + return false, nil, nil, fmt.Errorf("checking block parent header: %w", err) } - if imported { - f.syncedBlocks += 1 + if ok { + nextBlocksToImport = append(nextBlocksToImport, fragment...) 
+ continue } - } - nextBlocksToImport = make([]*types.BlockData, 0) - highestFinalized, err = f.blockState.GetHighestFinalisedHeader() - if err != nil { - return false, nil, nil, fmt.Errorf("getting highest finalized header") + disjointFragments = append(disjointFragments, fragment) } - // check if blocks from the disjoint set can be imported on their on forks - // given that fragment contains chains and these chains contains blocks - // check if the first block in the chain contains a parent known by us - for _, fragment := range disjointFragments { - validFragment := validBlocksUnderFragment(highestFinalized.Number, fragment) - if len(validFragment) == 0 { - continue + // this loop goal is to import ready blocks as well as update the highestFinalized header + for len(nextBlocksToImport) > 0 || len(disjointFragments) > 0 { + for _, blockToImport := range nextBlocksToImport { + imported, err := f.blockImporter.importBlock(blockToImport, networkInitialSync) + if err != nil { + return false, nil, nil, fmt.Errorf("while handling ready block: %w", err) + } + + if imported { + f.syncedBlocks += 1 + } } - ok, err := f.blockState.HasHeader(validFragment[0].Header.ParentHash) - if err != nil && !errors.Is(err, database.ErrNotFound) { - return false, nil, nil, err + nextBlocksToImport = make([]*types.BlockData, 0) + highestFinalized, err = f.blockState.GetHighestFinalisedHeader() + if err != nil { + return false, nil, nil, fmt.Errorf("getting highest finalized header") } - if !ok { - // if the parent of this valid fragment is behind our latest finalized number - // then we can discard the whole fragment since it is a invalid fork - if (validFragment[0].Header.Number - 1) <= highestFinalized.Number { + // check if blocks from the disjoint set can be imported or they're on forks + // given that fragment contains chains and these chains contains blocks + // check if the first block in the chain contains a parent known by us + for _, fragment := range disjointFragments { + validFragment := validBlocksUnderFragment(highestFinalized.Number, fragment) + if len(validFragment) == 0 { continue } - logger.Infof("starting an acestor search from %s parent of #%d (%s)", - validFragment[0].Header.ParentHash, - validFragment[0].Header.Number, - validFragment[0].Header.Hash(), - ) - - f.unreadyBlocks.newDisjointFragment(validFragment) - request := messages.NewBlockRequest( - *messages.NewFromBlock(validFragment[0].Header.ParentHash), - messages.MaxBlocksInResponse, - messages.BootstrapRequestData, messages.Descending) - f.requestQueue.PushBack(request) - } else { - // inserting them in the queue to be processed after the main chain - nextBlocksToImport = append(nextBlocksToImport, validFragment...) 
+ ok, err := f.blockState.HasHeader(validFragment[0].Header.ParentHash) + if err != nil && !errors.Is(err, database.ErrNotFound) { + return false, nil, nil, err + } + + if !ok { + // if the parent of this valid fragment is behind our latest finalized number + // then we can discard the whole fragment since it is a invalid fork + if (validFragment[0].Header.Number - 1) <= highestFinalized.Number { + continue + } + + logger.Infof("starting an ancestor search from %s parent of #%d (%s)", + validFragment[0].Header.ParentHash, + validFragment[0].Header.Number, + validFragment[0].Header.Hash(), + ) + + f.unreadyBlocks.newDisjointFragment(validFragment) + request := messages.NewBlockRequest( + *messages.NewFromBlock(validFragment[0].Header.ParentHash), + messages.MaxBlocksInResponse, + messages.BootstrapRequestData, messages.Descending) + f.requestQueue.PushBack(request) + } else { + // inserting them in the queue to be processed in the next loop iteration + nextBlocksToImport = append(nextBlocksToImport, validFragment...) + } } - } - disjointFragments = nil + disjointFragments = nil + } } f.unreadyBlocks.removeIrrelevantFragments(highestFinalized.Number) - return false, repChanges, peersToIgnore, nil + return false, reputations, bans, nil } func (f *FullSyncStrategy) ShowMetrics() { @@ -395,85 +407,79 @@ type RequestResponseData struct { responseData []*types.BlockData } -func validateResults(results []*SyncTaskResult, badBlocks []string) (repChanges []Change, - peersToBlock []peer.ID, validRes []RequestResponseData) { - - repChanges = make([]Change, 0) - peersToBlock = make([]peer.ID, 0) - validRes = make([]RequestResponseData, 0, len(results)) - -resultLoop: - for _, result := range results { - request := result.request.(*messages.BlockRequestMessage) +func validateResult(result TaskResult, badBlocks []string) (repChange *Change, + blockPeer bool, validRes *RequestResponseData) { - if !result.completed { - continue - } - - response := result.response.(*messages.BlockResponseMessage) - if request.Direction == messages.Descending { - // reverse blocks before pre-validating and placing in ready queue - slices.Reverse(response.BlockData) - } + if !result.Completed { + return + } - err := validateResponseFields(request, response.BlockData) - if err != nil { - logger.Warnf("validating fields: %s", err) - // TODO: check the reputation change for nil body in response - // and nil justification in response - if errors.Is(err, errNilHeaderInResponse) { - repChanges = append(repChanges, Change{ - who: result.who, - rep: peerset.ReputationChange{ - Value: peerset.IncompleteHeaderValue, - Reason: peerset.IncompleteHeaderReason, - }, - }) - } + task, ok := result.Task.(*syncTask) + if !ok { + logger.Warnf("skipping unexpected task type in TaskResult: %T", result.Task) + return + } - continue - } + request := task.request.(*messages.BlockRequestMessage) + response := result.Result.(*messages.BlockResponseMessage) + if request.Direction == messages.Descending { + // reverse blocks before pre-validating and placing in ready queue + slices.Reverse(response.BlockData) + } - // only check if the responses forms a chain if the response contains the headers - // of each block, othewise the response might only have the body/justification for - // a block - if request.RequestField(messages.RequestedDataHeader) && !isResponseAChain(response.BlockData) { - logger.Warnf("response from %s is not a chain", result.who) - repChanges = append(repChanges, Change{ - who: result.who, + err := validateResponseFields(request, 
response.BlockData) + if err != nil { + logger.Warnf("validating fields: %s", err) + // TODO: check the reputation change for nil body in response + // and nil justification in response + if errors.Is(err, errNilHeaderInResponse) { + repChange = &Change{ + who: result.Who, rep: peerset.ReputationChange{ Value: peerset.IncompleteHeaderValue, Reason: peerset.IncompleteHeaderReason, }, - }) - continue + } + return } + } - for _, block := range response.BlockData { - if slices.Contains(badBlocks, block.Hash.String()) { - logger.Warnf("%s sent a known bad block: #%d (%s)", - result.who, block.Number(), block.Hash.String()) - - peersToBlock = append(peersToBlock, result.who) - repChanges = append(repChanges, Change{ - who: result.who, - rep: peerset.ReputationChange{ - Value: peerset.BadBlockAnnouncementValue, - Reason: peerset.BadBlockAnnouncementReason, - }, - }) - - continue resultLoop - } + // only check if the block data in the response forms a chain if it contains the headers + // of each block, othewise the response might only have the body/justification for a block + if request.RequestField(messages.RequestedDataHeader) && !isResponseAChain(response.BlockData) { + logger.Warnf("response from %s is not a chain", result.Who) + repChange = &Change{ + who: result.Who, + rep: peerset.ReputationChange{ + Value: peerset.IncompleteHeaderValue, + Reason: peerset.IncompleteHeaderReason, + }, } + return + } - validRes = append(validRes, RequestResponseData{ - req: request, - responseData: response.BlockData, - }) + for _, block := range response.BlockData { + if slices.Contains(badBlocks, block.Hash.String()) { + logger.Warnf("%s sent a known bad block: #%d (%s)", + result.Who, block.Number(), block.Hash.String()) + + blockPeer = true + repChange = &Change{ + who: result.Who, + rep: peerset.ReputationChange{ + Value: peerset.BadBlockAnnouncementValue, + Reason: peerset.BadBlockAnnouncementReason, + }, + } + return + } } - return repChanges, peersToBlock, validRes + validRes = &RequestResponseData{ + req: request, + responseData: response.BlockData, + } + return } // sortFragmentsOfChain will organise the fragments diff --git a/dot/sync/fullsync_test.go b/dot/sync/fullsync_test.go index 0c9bbd4122..9b1bc1e794 100644 --- a/dot/sync/fullsync_test.go +++ b/dot/sync/fullsync_test.go @@ -7,6 +7,10 @@ import ( "container/list" "testing" + "gopkg.in/yaml.v3" + + _ "embed" + "github.com/ChainSafe/gossamer/dot/network" "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/ChainSafe/gossamer/dot/peerset" @@ -15,9 +19,6 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - "gopkg.in/yaml.v3" - - _ "embed" ) //go:embed testdata/westend_blocks.yaml @@ -69,8 +70,7 @@ func TestFullSyncNextActions(t *testing.T) { task, err := fs.NextActions() require.NoError(t, err) - require.Len(t, task, int(maxRequestsAllowed)) - request := task[0].request.(*messages.BlockRequestMessage) + request := task[0].(*syncTask).request.(*messages.BlockRequestMessage) require.Equal(t, uint(1), request.StartingBlock.RawValue()) require.Equal(t, uint32(128), *request.Max) }) @@ -171,7 +171,7 @@ func TestFullSyncNextActions(t *testing.T) { task, err := fs.NextActions() require.NoError(t, err) - require.Equal(t, task[0].request, tt.expectedTasks[0]) + require.Equal(t, task[0].(*syncTask).request, tt.expectedTasks[0]) require.Equal(t, fs.requestQueue.Len(), tt.expectedQueueLen) }) } @@ -192,37 +192,45 @@ func TestFullSyncProcess(t *testing.T) { require.NoError(t, err) 
t.Run("requested_max_but_received_less_blocks", func(t *testing.T) { - syncTaskResults := []*SyncTaskResult{ + ctrl := gomock.NewController(t) + requestMaker := NewMockRequestMaker(ctrl) + + syncTaskResults := []TaskResult{ // first task // 1 -> 10 { - who: peer.ID("peerA"), - request: messages.NewBlockRequest(*messages.NewFromBlock(uint(1)), 127, - messages.BootstrapRequestData, messages.Ascending), - completed: true, - response: fstTaskBlockResponse, + Who: peer.ID("peerA"), + Task: &syncTask{ + request: messages.NewBlockRequest(*messages.NewFromBlock(uint(1)), 127, + messages.BootstrapRequestData, messages.Ascending), + requestMaker: requestMaker, + }, + Completed: true, + Result: fstTaskBlockResponse, }, // there is gap from 11 -> 128 // second task // 129 -> 256 { - who: peer.ID("peerA"), - request: messages.NewBlockRequest(*messages.NewFromBlock(uint(129)), 127, - messages.BootstrapRequestData, messages.Ascending), - completed: true, - response: sndTaskBlockResponse, + Who: peer.ID("peerA"), + Task: &syncTask{ + request: messages.NewBlockRequest(*messages.NewFromBlock(uint(129)), 127, + messages.BootstrapRequestData, messages.Ascending), + requestMaker: requestMaker, + }, + Completed: true, + Result: sndTaskBlockResponse, }, } genesisHeader := types.NewHeader(fstTaskBlockResponse.BlockData[0].Header.ParentHash, common.Hash{}, common.Hash{}, 0, types.NewDigest()) - ctrl := gomock.NewController(t) mockBlockState := NewMockBlockState(ctrl) mockBlockState.EXPECT().GetHighestFinalisedHeader(). Return(genesisHeader, nil). - Times(4) + Times(5) mockBlockState.EXPECT(). HasHeader(fstTaskBlockResponse.BlockData[0].Header.ParentHash). @@ -247,7 +255,12 @@ func TestFullSyncProcess(t *testing.T) { fs := NewFullSyncStrategy(cfg) fs.blockImporter = mockImporter - done, _, _, err := fs.Process(syncTaskResults) + results := make(chan TaskResult, len(syncTaskResults)) + for _, result := range syncTaskResults { + results <- result + } + + done, _, _, err := fs.Process(results) require.NoError(t, err) require.False(t, done) @@ -271,18 +284,19 @@ func TestFullSyncProcess(t *testing.T) { err = ancestorSearchResponse.Decode(common.MustHexToBytes(westendBlocks.Blocks1To128)) require.NoError(t, err) - syncTaskResults = []*SyncTaskResult{ + results <- TaskResult{ // ancestor search task // 128 -> 1 - { - who: peer.ID("peerA"), - request: expectedAncestorRequest, - completed: true, - response: ancestorSearchResponse, + Who: peer.ID("peerA"), + Task: &syncTask{ + request: expectedAncestorRequest, + requestMaker: requestMaker, }, + Completed: true, + Result: ancestorSearchResponse, } - done, _, _, err = fs.Process(syncTaskResults) + done, _, _, err = fs.Process(results) require.NoError(t, err) require.False(t, done) @@ -293,7 +307,7 @@ func TestFullSyncProcess(t *testing.T) { } func TestFullSyncBlockAnnounce(t *testing.T) { - t.Run("announce_a_far_block_without_any_commom_ancestor", func(t *testing.T) { + t.Run("announce_a_far_block_without_any_common_ancestor", func(t *testing.T) { highestFinalizedHeader := &types.Header{ ParentHash: common.BytesToHash([]byte{0}), StateRoot: common.BytesToHash([]byte{3, 3, 3, 3}), @@ -347,7 +361,7 @@ func TestFullSyncBlockAnnounce(t *testing.T) { require.Zero(t, fs.requestQueue.Len()) }) - t.Run("announce_closer_valid_block_without_any_commom_ancestor", func(t *testing.T) { + t.Run("announce_closer_valid_block_without_any_common_ancestor", func(t *testing.T) { highestFinalizedHeader := &types.Header{ ParentHash: common.BytesToHash([]byte{0}), StateRoot: 
common.BytesToHash([]byte{3, 3, 3, 3}), @@ -457,7 +471,7 @@ func TestFullSyncBlockAnnounce(t *testing.T) { requests := make([]messages.P2PMessage, len(tasks)) for idx, task := range tasks { - requests[idx] = task.request + requests[idx] = task.(*syncTask).request } block17 := types.NewHeader(announceOfBlock17.ParentHash, diff --git a/dot/sync/service.go b/dot/sync/service.go index 8873de0002..7cc28b0410 100644 --- a/dot/sync/service.go +++ b/dot/sync/service.go @@ -10,6 +10,7 @@ import ( "time" "github.com/ChainSafe/gossamer/dot/network" + "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/ChainSafe/gossamer/dot/peerset" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/internal/log" @@ -88,12 +89,27 @@ type Change struct { type Strategy interface { OnBlockAnnounce(from peer.ID, msg *network.BlockAnnounceMessage) (repChange *Change, err error) OnBlockAnnounceHandshake(from peer.ID, msg *network.BlockAnnounceHandshake) error - NextActions() ([]*SyncTask, error) - Process(results []*SyncTaskResult) (done bool, repChanges []Change, blocks []peer.ID, err error) + NextActions() ([]Task, error) + Process(results <-chan TaskResult) (done bool, repChanges []Change, blocks []peer.ID, err error) ShowMetrics() IsSynced() bool } +type syncTask struct { + requestMaker network.RequestMaker + request messages.P2PMessage +} + +func (s *syncTask) ID() TaskID { + return TaskID(s.request.String()) +} + +func (s *syncTask) Do(p peer.ID) (Result, error) { + response := messages.BlockResponseMessage{} + err := s.requestMaker.Do(p, s.request, &response) + return &response, err +} + type SyncService struct { mu sync.Mutex wg sync.WaitGroup @@ -103,7 +119,7 @@ type SyncService struct { currentStrategy Strategy defaultStrategy Strategy - workerPool *syncWorkerPool + workerPool WorkerPool waitPeersDuration time.Duration minPeers int slotDuration time.Duration @@ -119,6 +135,11 @@ func NewSyncService(cfgs ...ServiceConfig) *SyncService { waitPeersDuration: waitPeersDefaultTimeout, stopCh: make(chan struct{}), seenBlockSyncRequests: lrucache.NewLRUCache[common.Hash, uint](100), + workerPool: NewWorkerPool(WorkerPoolConfig{ + MaxRetries: 5, + // TODO: This should depend on the actual configuration of the currently used sync strategy. 
+ Capacity: defaultNumOfTasks * 10, + }), } for _, cfg := range cfgs { @@ -135,7 +156,7 @@ func (s *SyncService) waitWorkers() { } for { - total := s.workerPool.totalWorkers() + total := s.workerPool.NumPeers() if total >= s.minPeers { return } @@ -164,6 +185,7 @@ func (s *SyncService) Start() error { } func (s *SyncService) Stop() error { + s.workerPool.Shutdown() close(s.stopCh) s.wg.Wait() return nil @@ -171,7 +193,9 @@ func (s *SyncService) Stop() error { func (s *SyncService) HandleBlockAnnounceHandshake(from peer.ID, msg *network.BlockAnnounceHandshake) error { logger.Infof("receiving a block announce handshake from %s", from.String()) - if err := s.workerPool.fromBlockAnnounceHandshake(from); err != nil { + logger.Infof("len(s.workerPool.Results())=%d", len(s.workerPool.Results())) // TODO: remove + if err := s.workerPool.AddPeer(from); err != nil { + logger.Warnf("failed to add peer to worker pool: %s", err) return err } @@ -199,7 +223,7 @@ func (s *SyncService) HandleBlockAnnounce(from peer.ID, msg *network.BlockAnnoun func (s *SyncService) OnConnectionClosed(who peer.ID) { logger.Tracef("removing peer worker: %s", who.String()) - s.workerPool.removeWorker(who) + s.workerPool.RemovePeer(who) } func (s *SyncService) IsSynced() bool { @@ -249,19 +273,20 @@ func (s *SyncService) runStrategy() { finalisedHeader, err := s.blockState.GetHighestFinalisedHeader() if err != nil { - logger.Criticalf("getting highest finalized header: %w", err) + logger.Criticalf("getting highest finalized header: %s", err) return } bestBlockHeader, err := s.blockState.BestBlockHeader() if err != nil { - logger.Criticalf("getting best block header: %w", err) + logger.Criticalf("getting best block header: %s", err) return } logger.Infof( - "🚣 currently syncing, %d peers connected, finalized #%d (%s), best #%d (%s)", + "🚣 currently syncing, %d peers connected, %d peers in the worker pool, finalized #%d (%s), best #%d (%s)", len(s.network.AllConnectedPeersIDs()), + s.workerPool.NumPeers(), finalisedHeader.Number, finalisedHeader.Hash().Short(), bestBlockHeader.Number, @@ -279,8 +304,13 @@ func (s *SyncService) runStrategy() { return } - results := s.workerPool.submitRequests(tasks) - done, repChanges, peersToIgnore, err := s.currentStrategy.Process(results) + _, err = s.workerPool.SubmitBatch(tasks) + if err != nil { + logger.Criticalf("current sync strategy next actions failed with: %s", err.Error()) + return + } + + done, repChanges, peersToIgnore, err := s.currentStrategy.Process(s.workerPool.Results()) if err != nil { logger.Criticalf("current sync strategy failed with: %s", err.Error()) return @@ -291,7 +321,7 @@ func (s *SyncService) runStrategy() { } for _, block := range peersToIgnore { - s.workerPool.ignorePeerAsWorker(block) + s.workerPool.IgnorePeer(block) } s.currentStrategy.ShowMetrics() diff --git a/dot/sync/worker_pool.go b/dot/sync/worker_pool.go index b11b726db7..a85a33053b 100644 --- a/dot/sync/worker_pool.go +++ b/dot/sync/worker_pool.go @@ -1,191 +1,358 @@ -// Copyright 2023 ChainSafe Systems (ON) +// Copyright 2024 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only package sync import ( + "container/list" + "context" "errors" + "fmt" "sync" "time" - "github.com/ChainSafe/gossamer/dot/network" - "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/libp2p/go-libp2p/core/peer" - "golang.org/x/exp/maps" ) +const defaultWorkerPoolCapacity = 100 + var ( - ErrNoPeersToMakeRequest = errors.New("no peers to make requests") - ErrPeerIgnored = errors.New("peer ignored") + 
ErrNoPeers = errors.New("no peers available") + ErrPeerIgnored = errors.New("peer ignored") ) -const ( - punishmentBaseTimeout = 5 * time.Minute - maxRequestsAllowed uint = 3 -) +type TaskID string +type Result any -type SyncTask struct { - requestMaker network.RequestMaker - request messages.P2PMessage - response messages.P2PMessage +type Task interface { + ID() TaskID + Do(p peer.ID) (Result, error) } -type SyncTaskResult struct { - who peer.ID - completed bool - request messages.P2PMessage - response messages.P2PMessage +type TaskResult struct { + Task Task + Completed bool + Result Result + Error error + Retries uint + Who peer.ID } -type syncWorkerPool struct { - mtx sync.RWMutex +func (t TaskResult) Failed() bool { + return t.Error != nil +} - network Network - workers map[peer.ID]struct{} - ignorePeers map[peer.ID]struct{} +type BatchStatus struct { + Failed map[TaskID]TaskResult + Success map[TaskID]TaskResult } -func newSyncWorkerPool(net Network) *syncWorkerPool { - swp := &syncWorkerPool{ - network: net, - workers: make(map[peer.ID]struct{}), - ignorePeers: make(map[peer.ID]struct{}), +func (bs BatchStatus) Completed(todo int) bool { + if len(bs.Failed)+len(bs.Success) < todo { + return false } - return swp + for _, tr := range bs.Failed { + if !tr.Completed { + return false + } + } + + for _, tr := range bs.Success { + if !tr.Completed { + return false + } + } + + return true } -// fromBlockAnnounceHandshake stores the peer which send us a handshake as -// a possible source for requesting blocks/state/warp proofs -func (s *syncWorkerPool) fromBlockAnnounceHandshake(who peer.ID) error { - s.mtx.Lock() - defer s.mtx.Unlock() +type BatchID string + +type WorkerPool interface { + SubmitBatch(tasks []Task) (id BatchID, err error) + GetBatch(id BatchID) (status BatchStatus, ok bool) + Results() chan TaskResult + AddPeer(p peer.ID) error + RemovePeer(p peer.ID) + IgnorePeer(p peer.ID) + NumPeers() int + Shutdown() +} - if _, ok := s.ignorePeers[who]; ok { - return ErrPeerIgnored +type WorkerPoolConfig struct { + Capacity uint + MaxRetries uint +} + +// NewWorkerPool creates a new worker pool with the given configuration. +func NewWorkerPool(cfg WorkerPoolConfig) WorkerPool { + ctx, cancel := context.WithCancel(context.Background()) + + if cfg.Capacity == 0 { + cfg.Capacity = defaultWorkerPoolCapacity } - _, has := s.workers[who] - if has { - return nil + return &workerPool{ + maxRetries: cfg.MaxRetries, + ignoredPeers: make(map[peer.ID]struct{}), + statuses: make(map[BatchID]BatchStatus), + resChan: make(chan TaskResult, cfg.Capacity), + ctx: ctx, + cancel: cancel, } +} - s.workers[who] = struct{}{} - logger.Tracef("potential worker added, total in the pool %d", len(s.workers)) - return nil +type workerPool struct { + mtx sync.RWMutex + wg sync.WaitGroup + + maxRetries uint + peers list.List + ignoredPeers map[peer.ID]struct{} + statuses map[BatchID]BatchStatus + resChan chan TaskResult + ctx context.Context + cancel context.CancelFunc } -// submitRequests blocks until all tasks have been completed or there are no workers -// left in the pool to retry failed tasks -func (s *syncWorkerPool) submitRequests(tasks []*SyncTask) []*SyncTaskResult { - if len(tasks) == 0 { - return nil +// SubmitBatch accepts a list of tasks and immediately returns a batch ID. The batch ID can be used to query the status +// of the batch using [GetBatchStatus]. +// TODO +// If tasks are submitted faster than they are completed, resChan will run full, blocking the calling goroutine. 
+// Ideally this method would provide backpressure to the caller in that case. The rejected tasks should then stay in +// FullSyncStrategy.requestQueue until the next round. But this would need to be supported in all sync strategies. +func (w *workerPool) SubmitBatch(tasks []Task) (id BatchID, err error) { + w.mtx.Lock() + defer w.mtx.Unlock() + + bID := BatchID(fmt.Sprintf("%d", time.Now().UnixNano())) + + w.statuses[bID] = BatchStatus{ + Failed: make(map[TaskID]TaskResult), + Success: make(map[TaskID]TaskResult), } - s.mtx.RLock() - defer s.mtx.RUnlock() + w.wg.Add(1) + go func() { + defer w.wg.Done() + w.executeBatch(tasks, bID) + }() - pids := maps.Keys(s.workers) - workerPool := make(chan peer.ID, len(pids)) - for _, worker := range pids { - workerPool <- worker + return bID, nil +} + +// GetBatch returns the status of a batch previously submitted using [SubmitBatch]. +func (w *workerPool) GetBatch(id BatchID) (status BatchStatus, ok bool) { + w.mtx.RLock() + defer w.mtx.RUnlock() + + status, ok = w.statuses[id] + return +} + +// Results returns a channel that can be used to receive the results of completed tasks. +func (w *workerPool) Results() chan TaskResult { + return w.resChan +} + +// AddPeer adds a peer to the worker pool unless it has been ignored previously. +func (w *workerPool) AddPeer(who peer.ID) error { + w.mtx.Lock() + defer w.mtx.Unlock() + + if _, ok := w.ignoredPeers[who]; ok { + return ErrPeerIgnored + } + + for e := w.peers.Front(); e != nil; e = e.Next() { + if e.Value.(peer.ID) == who { + return nil + } } - failedTasks := make(chan *SyncTask, len(tasks)) - results := make(chan *SyncTaskResult, len(tasks)) + w.peers.PushBack(who) + logger.Tracef("peer added, total in the pool %d", w.peers.Len()) + return nil +} + +// RemovePeer removes a peer from the worker pool. +func (w *workerPool) RemovePeer(who peer.ID) { + w.mtx.Lock() + defer w.mtx.Unlock() + + w.removePeer(who) +} + +// IgnorePeer removes a peer from the worker pool and prevents it from being added again. +func (w *workerPool) IgnorePeer(who peer.ID) { + w.mtx.Lock() + defer w.mtx.Unlock() + + w.removePeer(who) + w.ignoredPeers[who] = struct{}{} +} + +// NumPeers returns the number of peers in the worker pool, both busy and free. +func (w *workerPool) NumPeers() int { + w.mtx.RLock() + defer w.mtx.RUnlock() - var wg sync.WaitGroup - for _, task := range tasks { - wg.Add(1) - go func(t *SyncTask) { - defer wg.Done() - executeTask(t, workerPool, failedTasks, results) - }(task) + return w.peers.Len() +} + +// Shutdown stops the worker pool and waits for all tasks to complete. 
+func (w *workerPool) Shutdown() { + w.cancel() + w.wg.Wait() +} + +func (w *workerPool) executeBatch(tasks []Task, bID BatchID) { + batchResults := make(chan TaskResult, len(tasks)) + + for _, t := range tasks { + w.wg.Add(1) + go func(t Task) { + defer w.wg.Done() + w.executeTask(t, batchResults) + }(t) } - wg.Add(1) - go func() { - defer wg.Done() - for task := range failedTasks { - if len(workerPool) > 0 { - wg.Add(1) - go func(t *SyncTask) { - defer wg.Done() - executeTask(t, workerPool, failedTasks, results) - }(task) + for { + select { + case <-w.ctx.Done(): + return + + case tr := <-batchResults: + if tr.Failed() { + w.handleFailedTask(tr, bID, batchResults) } else { - results <- &SyncTaskResult{ - completed: false, - request: task.request, - response: nil, - } + w.handleSuccessfulTask(tr, bID) } - } - }() - allResults := make(chan []*SyncTaskResult, 1) - wg.Add(1) - go func(expectedResults int) { - defer wg.Done() - var taskResults []*SyncTaskResult - - for result := range results { - taskResults = append(taskResults, result) - if len(taskResults) == expectedResults { - close(failedTasks) - break + if w.batchCompleted(bID, len(tasks)) { + return } } + } +} - allResults <- taskResults - }(len(tasks)) - - wg.Wait() - close(workerPool) - close(results) +func (w *workerPool) executeTask(task Task, ch chan TaskResult) { + if errors.Is(w.ctx.Err(), context.Canceled) { + logger.Tracef("[CANCELED] task=%s, shutting down", task.ID()) + return + } - return <-allResults -} + who, err := w.reservePeer() + if errors.Is(err, ErrNoPeers) { + logger.Tracef("no peers available for task=%s", task.ID()) + ch <- TaskResult{Task: task, Error: ErrNoPeers} + return + } -func executeTask(task *SyncTask, workerPool chan peer.ID, failedTasks chan *SyncTask, results chan *SyncTaskResult) { - worker := <-workerPool - logger.Infof("[EXECUTING] worker %s", worker) + logger.Infof("[EXECUTING] task=%s", task.ID()) - err := task.requestMaker.Do(worker, task.request, task.response) + result, err := task.Do(who) if err != nil { - logger.Infof("[ERR] worker %s, request: %s, err: %s", worker, task.request.String(), err.Error()) - failedTasks <- task + logger.Tracef("[FAILED] task=%s peer=%s, err=%s", task.ID(), who, err.Error()) } else { - logger.Infof("[FINISHED] worker %s, request: %s", worker, task.request.String()) - workerPool <- worker - results <- &SyncTaskResult{ - who: worker, - completed: true, - request: task.request, - response: task.response, + logger.Tracef("[FINISHED] task=%s peer=%s", task.ID(), who) + } + + w.mtx.Lock() + w.peers.PushBack(who) + w.mtx.Unlock() + + ch <- TaskResult{ + Task: task, + Who: who, + Result: result, + Error: err, + Retries: 0, + } +} + +func (w *workerPool) reservePeer() (who peer.ID, err error) { + w.mtx.Lock() + defer w.mtx.Unlock() + + peerElement := w.peers.Front() + + if peerElement == nil { + return who, ErrNoPeers + } + + w.peers.Remove(peerElement) + return peerElement.Value.(peer.ID), nil +} + +func (w *workerPool) removePeer(who peer.ID) { + var toRemove *list.Element + for e := w.peers.Front(); e != nil; e = e.Next() { + if e.Value.(peer.ID) == who { + toRemove = e + break } } + + if toRemove != nil { + w.peers.Remove(toRemove) + } } -func (s *syncWorkerPool) ignorePeerAsWorker(who peer.ID) { - s.mtx.Lock() - defer s.mtx.Unlock() +func (w *workerPool) handleSuccessfulTask(tr TaskResult, batchID BatchID) { + w.mtx.Lock() + defer w.mtx.Unlock() + + tID := tr.Task.ID() + + if failedTr, ok := w.statuses[batchID].Failed[tID]; ok { + tr.Retries = failedTr.Retries + 1 + 
delete(w.statuses[batchID].Failed, tID) + } - delete(s.workers, who) - s.ignorePeers[who] = struct{}{} + tr.Completed = true + w.statuses[batchID].Success[tID] = tr + logger.Infof("handleSuccessfulTask(): len(w.resChan)=%d", len(w.resChan)) // TODO: remove + w.resChan <- tr } -func (s *syncWorkerPool) removeWorker(who peer.ID) { - s.mtx.Lock() - defer s.mtx.Unlock() +func (w *workerPool) handleFailedTask(tr TaskResult, batchID BatchID, batchResults chan TaskResult) { + w.mtx.Lock() + defer w.mtx.Unlock() + + tID := tr.Task.ID() + + if oldTr, ok := w.statuses[batchID].Failed[tID]; ok { + // It is only considered a retry if the task was actually executed. + if errors.Is(oldTr.Error, ErrNoPeers) { + // TODO Should we sleep a bit to wait for peers? + } else { + tr.Retries = oldTr.Retries + 1 + tr.Completed = tr.Retries >= w.maxRetries + } + } + + w.statuses[batchID].Failed[tID] = tr + + if tr.Completed { + logger.Infof("handleFailedTask(): len(w.resChan)=%d", len(w.resChan)) // TODO: remove + w.resChan <- tr + return + } - delete(s.workers, who) + // retry task + w.wg.Add(1) + go func() { + defer w.wg.Done() + w.executeTask(tr.Task, batchResults) + }() } -// totalWorkers only returns available or busy workers -func (s *syncWorkerPool) totalWorkers() (total int) { - s.mtx.RLock() - defer s.mtx.RUnlock() +func (w *workerPool) batchCompleted(id BatchID, todo int) bool { + w.mtx.Lock() + defer w.mtx.Unlock() - return len(s.workers) + b, ok := w.statuses[id] + return !ok || b.Completed(todo) } diff --git a/dot/sync/worker_pool_test.go b/dot/sync/worker_pool_test.go new file mode 100644 index 0000000000..676b787e78 --- /dev/null +++ b/dot/sync/worker_pool_test.go @@ -0,0 +1,268 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package sync + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/assert" +) + +type mockTask struct { + id TaskID + err error + execCount uint + succeedAfter uint +} + +func (m *mockTask) ID() TaskID { + return m.id +} + +func (m *mockTask) Do(p peer.ID) (Result, error) { + time.Sleep(time.Millisecond * 100) // simulate network roundtrip + defer func() { + m.execCount++ + }() + + res := Result(fmt.Sprintf("%s - %s great success!", m.id, p)) + if m.err != nil { + if m.succeedAfter > 0 && m.execCount >= m.succeedAfter { + return res, nil + } + return nil, m.err + } + return res, nil +} + +func (m *mockTask) String() string { + return fmt.Sprintf("mockTask %s", m.id) +} + +func makeTasksAndPeers(num, idOffset int) ([]Task, []peer.ID) { + tasks := make([]Task, num) + peers := make([]peer.ID, num) + + for i := 0; i < num; i++ { + tasks[i] = &mockTask{id: TaskID(fmt.Sprintf("t-%d", i+idOffset))} + peers[i] = peer.ID(fmt.Sprintf("p-%d", i+idOffset)) + } + return tasks, peers +} + +func waitForCompletion(wp WorkerPool, numTasks int) { + resultsReceived := 0 + + for { + <-wp.Results() + resultsReceived++ + + if resultsReceived == numTasks { + break + } + } +} + +func TestWorkerPoolHappyPath(t *testing.T) { + numTasks := 10 + + var setup = func() (WorkerPool, []Task) { + tasks, peers := makeTasksAndPeers(numTasks, 0) + wp := NewWorkerPool(WorkerPoolConfig{}) + + for _, who := range peers { + err := wp.AddPeer(who) + assert.NoError(t, err) + } + + return wp, tasks + } + + t.Run("receive_results_on_channel", func(t *testing.T) { + wp, tasks := setup() + results := make([]TaskResult, 0, numTasks) + _, err := wp.SubmitBatch(tasks) + + assert.NoError(t, err) + + for { + result 
:= <-wp.Results() + assert.True(t, result.Completed) + assert.False(t, result.Failed()) + assert.Equal(t, uint(0), result.Retries) + + results = append(results, result) + if len(results) == numTasks { + break + } + } + }) + + t.Run("check_batch_status_on_completion", func(t *testing.T) { + wp, tasks := setup() + batchID, err := wp.SubmitBatch(tasks) + assert.NoError(t, err) + + waitForCompletion(wp, numTasks) + status, ok := wp.GetBatch(batchID) + + assert.True(t, ok) + assert.True(t, status.Completed(numTasks)) + assert.Equal(t, numTasks, len(status.Success)) + assert.Equal(t, 0, len(status.Failed)) + }) +} + +func TestWorkerPoolPeerHandling(t *testing.T) { + numTasks := 3 + + t.Run("accepts_batch_without_any_peers", func(t *testing.T) { + tasks, _ := makeTasksAndPeers(numTasks, 0) + wp := NewWorkerPool(WorkerPoolConfig{}) + + _, err := wp.SubmitBatch(tasks) + assert.NoError(t, err) + + wp.Shutdown() + }) + + t.Run("completes_batch_with_fewer_peers_than_tasks", func(t *testing.T) { + tasks, peers := makeTasksAndPeers(numTasks, 0) + wp := NewWorkerPool(WorkerPoolConfig{}) + assert.NoError(t, wp.AddPeer(peers[0])) + assert.NoError(t, wp.AddPeer(peers[1])) + + bID, err := wp.SubmitBatch(tasks) + assert.NoError(t, err) + + waitForCompletion(wp, numTasks) + status, ok := wp.GetBatch(bID) + assert.True(t, ok) + assert.True(t, status.Completed(numTasks)) + assert.Equal(t, numTasks, len(status.Success)) + assert.Equal(t, 0, len(status.Failed)) + }) + + t.Run("refuses_to_re_add_ignored_peer", func(t *testing.T) { + _, peers := makeTasksAndPeers(numTasks, 0) + wp := NewWorkerPool(WorkerPoolConfig{}) + + for _, who := range peers { + err := wp.AddPeer(who) + assert.NoError(t, err) + } + assert.Equal(t, len(peers), wp.NumPeers()) + + badPeer := peers[2] + wp.IgnorePeer(badPeer) + assert.Equal(t, len(peers)-1, wp.NumPeers()) + + err := wp.AddPeer(badPeer) + assert.ErrorIs(t, err, ErrPeerIgnored) + assert.Equal(t, len(peers)-1, wp.NumPeers()) + }) +} + +func TestWorkerPoolTaskFailures(t *testing.T) { + numTasks := 3 + taskErr := errors.New("kaput") + + setup := func(maxRetries uint) (failOnce *mockTask, failTwice *mockTask, batchID BatchID, wp WorkerPool) { + tasks, peers := makeTasksAndPeers(numTasks, 0) + + failOnce = tasks[1].(*mockTask) + failOnce.err = taskErr + failOnce.succeedAfter = 1 + + failTwice = tasks[2].(*mockTask) + failTwice.err = taskErr + failTwice.succeedAfter = 2 + + wp = NewWorkerPool(WorkerPoolConfig{MaxRetries: maxRetries}) + for _, who := range peers { + err := wp.AddPeer(who) + assert.NoError(t, err) + } + + var err error + batchID, err = wp.SubmitBatch(tasks) + assert.NoError(t, err) + return + } + + t.Run("retries_failed_tasks", func(t *testing.T) { + failOnce, failTwice, batchID, wp := setup(10) + waitForCompletion(wp, numTasks) + + status, ok := wp.GetBatch(batchID) + assert.True(t, ok) + assert.True(t, status.Completed(numTasks)) + assert.Equal(t, numTasks, len(status.Success)) + assert.Equal(t, 0, len(status.Failed)) + + assert.Nil(t, status.Failed[failOnce.ID()].Error) + assert.Equal(t, uint(1), status.Success[failOnce.ID()].Retries) + + assert.Nil(t, status.Failed[failTwice.ID()].Error) + assert.Equal(t, uint(2), status.Success[failTwice.ID()].Retries) + }) + + t.Run("honours_max_retries", func(t *testing.T) { + failOnce, failTwice, batchID, wp := setup(1) + waitForCompletion(wp, numTasks) + + status, ok := wp.GetBatch(batchID) + assert.True(t, ok) + assert.True(t, status.Completed(numTasks)) + assert.Equal(t, numTasks-1, len(status.Success)) + assert.Equal(t, 1, 
len(status.Failed)) + + assert.Nil(t, status.Failed[failOnce.ID()].Error) + assert.Equal(t, uint(1), status.Success[failOnce.ID()].Retries) + + assert.ErrorIs(t, taskErr, status.Failed[failTwice.ID()].Error) + assert.Equal(t, uint(1), status.Failed[failTwice.ID()].Retries) + }) +} + +func TestWorkerPoolMultipleBatches(t *testing.T) { + b1NumTasks := 10 + b2NumTasks := 12 + + t.Run("completes_all_batches", func(t *testing.T) { + b1Tasks, b1Peers := makeTasksAndPeers(b1NumTasks, 0) + b2Tasks, b2Peers := makeTasksAndPeers(b2NumTasks, b1NumTasks) + peers := append(b1Peers, b2Peers...) + + wp := NewWorkerPool(WorkerPoolConfig{}) + for _, who := range peers { + err := wp.AddPeer(who) + assert.NoError(t, err) + } + + b1ID, err := wp.SubmitBatch(b1Tasks) + assert.NoError(t, err) + + b2ID, err := wp.SubmitBatch(b2Tasks) + assert.NoError(t, err) + + waitForCompletion(wp, b1NumTasks+b2NumTasks) + + b1Status, ok := wp.GetBatch(b1ID) + assert.True(t, ok) + assert.True(t, b1Status.Completed(b1NumTasks)) + assert.Equal(t, b1NumTasks, len(b1Status.Success)) + assert.Equal(t, 0, len(b1Status.Failed)) + + b2Status, ok := wp.GetBatch(b2ID) + assert.True(t, ok) + assert.True(t, b2Status.Completed(b2NumTasks)) + assert.Equal(t, b2NumTasks, len(b2Status.Success)) + assert.Equal(t, 0, len(b2Status.Failed)) + }) +}
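
For reference, a minimal usage sketch of the WorkerPool API introduced in dot/sync/worker_pool.go above: register peers, submit a batch of tasks, drain Results(), then inspect the batch status via GetBatch. The echoTask type, the peer IDs, and exampleWorkerPoolUsage are hypothetical illustrations only and are not part of this change; the pattern mirrors SyncService.runStrategy and the new worker_pool_test.go.

package sync

import (
	"fmt"

	"github.com/libp2p/go-libp2p/core/peer"
)

// echoTask is a hypothetical Task implementation used only for illustration.
type echoTask struct {
	id TaskID
}

func (e *echoTask) ID() TaskID { return e.id }

// Do pretends to answer the request using the given peer.
func (e *echoTask) Do(p peer.ID) (Result, error) {
	return fmt.Sprintf("task %s answered by %s", e.id, p), nil
}

func exampleWorkerPoolUsage() error {
	wp := NewWorkerPool(WorkerPoolConfig{MaxRetries: 5, Capacity: 100})
	defer wp.Shutdown()

	// Peers normally arrive via block announce handshakes
	// (AddPeer is called from HandleBlockAnnounceHandshake).
	for _, who := range []peer.ID{"peer-a", "peer-b"} {
		if err := wp.AddPeer(who); err != nil {
			return err // e.g. ErrPeerIgnored
		}
	}

	tasks := []Task{&echoTask{id: "t-1"}, &echoTask{id: "t-2"}}
	batchID, err := wp.SubmitBatch(tasks)
	if err != nil {
		return err
	}

	// Drain the results channel; a sync strategy's Process() consumes this channel instead.
	for received := 0; received < len(tasks); received++ {
		res := <-wp.Results()
		fmt.Printf("task %s completed=%t retries=%d err=%v\n",
			res.Task.ID(), res.Completed, res.Retries, res.Error)
	}

	if status, ok := wp.GetBatch(batchID); ok && status.Completed(len(tasks)) {
		fmt.Printf("batch done: %d succeeded, %d failed\n",
			len(status.Success), len(status.Failed))
	}
	return nil
}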