diff --git a/p2p/discover/lookup.go b/p2p/discover/lookup.go
index 09808b71e07..9cca0118ac6 100644
--- a/p2p/discover/lookup.go
+++ b/p2p/discover/lookup.go
@@ -27,6 +27,7 @@ import (
 // lookup performs a network search for nodes close to the given target. It approaches the
 // target by querying nodes that are closer to it on each iteration. The given target does
 // not need to be an actual node identifier.
+// A lookup on an empty table returns immediately with no nodes.
 type lookup struct {
 	tab       *Table
 	queryfunc queryFunc
@@ -49,11 +50,15 @@ func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *l
 		result:   nodesByDistance{target: target},
 		replyCh:  make(chan []*enode.Node, alpha),
 		cancelCh: ctx.Done(),
-		queries:  -1,
 	}
 	// Don't query further if we hit ourself.
 	// Unlikely to happen often in practice.
 	it.asked[tab.self().ID()] = true
+	it.seen[tab.self().ID()] = true
+
+	// Initialize the lookup with nodes from the table.
+	closest := it.tab.findnodeByID(it.result.target, bucketSize, false)
+	it.addNodes(closest.entries)
 	return it
 }
 
@@ -64,22 +69,19 @@ func (it *lookup) run() []*enode.Node {
 	return it.result.entries
 }
 
+func (it *lookup) empty() bool {
+	return len(it.replyBuffer) == 0
+}
+
 // advance advances the lookup until any new nodes have been found.
 // It returns false when the lookup has ended.
 func (it *lookup) advance() bool {
 	for it.startQueries() {
 		select {
 		case nodes := <-it.replyCh:
-			it.replyBuffer = it.replyBuffer[:0]
-			for _, n := range nodes {
-				if n != nil && !it.seen[n.ID()] {
-					it.seen[n.ID()] = true
-					it.result.push(n, bucketSize)
-					it.replyBuffer = append(it.replyBuffer, n)
-				}
-			}
 			it.queries--
-			if len(it.replyBuffer) > 0 {
+			it.addNodes(nodes)
+			if !it.empty() {
 				return true
 			}
 		case <-it.cancelCh:
@@ -89,6 +91,17 @@ func (it *lookup) advance() bool {
 	return false
 }
 
+func (it *lookup) addNodes(nodes []*enode.Node) {
+	it.replyBuffer = it.replyBuffer[:0]
+	for _, n := range nodes {
+		if n != nil && !it.seen[n.ID()] {
+			it.seen[n.ID()] = true
+			it.result.push(n, bucketSize)
+			it.replyBuffer = append(it.replyBuffer, n)
+		}
+	}
+}
+
 func (it *lookup) shutdown() {
 	for it.queries > 0 {
 		<-it.replyCh
@@ -103,20 +116,6 @@ func (it *lookup) startQueries() bool {
 		return false
 	}
 
-	// The first query returns nodes from the local table.
-	if it.queries == -1 {
-		closest := it.tab.findnodeByID(it.result.target, bucketSize, false)
-		// Avoid finishing the lookup too quickly if table is empty. It'd be better to wait
-		// for the table to fill in this case, but there is no good mechanism for that
-		// yet.
-		if len(closest.entries) == 0 {
-			it.slowdown()
-		}
-		it.queries = 1
-		it.replyCh <- closest.entries
-		return true
-	}
-
 	// Ask the closest nodes that we haven't asked yet.
 	for i := 0; i < len(it.result.entries) && it.queries < alpha; i++ {
 		n := it.result.entries[i]
@@ -130,15 +129,6 @@ func (it *lookup) startQueries() bool {
 	return it.queries > 0
 }
 
-func (it *lookup) slowdown() {
-	sleep := time.NewTimer(1 * time.Second)
-	defer sleep.Stop()
-	select {
-	case <-sleep.C:
-	case <-it.tab.closeReq:
-	}
-}
-
 func (it *lookup) query(n *enode.Node, reply chan<- []*enode.Node) {
 	r, err := it.queryfunc(n)
 	if !errors.Is(err, errClosed) { // avoid recording failures on shutdown.
@@ -153,12 +143,16 @@ func (it *lookup) query(n *enode.Node, reply chan<- []*enode.Node) {
 
 // lookupIterator performs lookup operations and iterates over all seen nodes.
 // When a lookup finishes, a new one is created through nextLookup.
+//
+// The iterator waits for table initialization and triggers a table refresh
+// when necessary.
 type lookupIterator struct {
-	buffer     []*enode.Node
-	nextLookup lookupFunc
-	ctx        context.Context
-	cancel     func()
-	lookup     *lookup
+	buffer        []*enode.Node
+	nextLookup    lookupFunc
+	ctx           context.Context
+	cancel        func()
+	lookup        *lookup
+	tabRefreshing <-chan struct{}
 }
 
 type lookupFunc func(ctx context.Context) *lookup
@@ -182,6 +176,7 @@ func (it *lookupIterator) Next() bool {
 	if len(it.buffer) > 0 {
 		it.buffer = it.buffer[1:]
 	}
+
 	// Advance the lookup to refill the buffer.
 	for len(it.buffer) == 0 {
 		if it.ctx.Err() != nil {
@@ -191,17 +186,55 @@ func (it *lookupIterator) Next() bool {
 		}
 		if it.lookup == nil {
 			it.lookup = it.nextLookup(it.ctx)
+			if it.lookup.empty() {
+				// If the lookup is empty right after creation, it means the local table
+				// is in a degraded state, and we need to wait for it to fill again.
+				it.lookupFailed(it.lookup.tab, 1*time.Minute)
+				it.lookup = nil
+				continue
+			}
+			// Yield the initial nodes from the table before advancing the lookup.
+			it.buffer = it.lookup.replyBuffer
 			continue
 		}
-		if !it.lookup.advance() {
+
+		newNodes := it.lookup.advance()
+		it.buffer = it.lookup.replyBuffer
+		if !newNodes {
 			it.lookup = nil
-			continue
 		}
-		it.buffer = it.lookup.replyBuffer
 	}
 	return true
 }
 
+// lookupFailed handles failed lookup attempts. This can be called when the table
+// has exited, or when the lookup runs out of nodes.
+func (it *lookupIterator) lookupFailed(tab *Table, timeout time.Duration) {
+	tout, cancel := context.WithTimeout(it.ctx, timeout)
+	defer cancel()
+
+	// Wait for Table initialization to complete, in case it is still in progress.
+	select {
+	case <-tab.initDone:
+	case <-tout.Done():
+		return
+	}
+
+	// Wait for an ongoing refresh operation, or trigger one.
+	if it.tabRefreshing == nil {
+		it.tabRefreshing = tab.refresh()
+	}
+	select {
+	case <-it.tabRefreshing:
+		it.tabRefreshing = nil
+	case <-tout.Done():
+		return
+	}
+
+	// Wait for the table to fill.
+	tab.waitForNodes(tout, 1)
+}
+
 // Close ends the iterator.
 func (it *lookupIterator) Close() {
 	it.cancel()
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index b6c35aaaa94..6a1c7494eeb 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -32,6 +32,7 @@ import (
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/common/mclock"
+	"github.com/ethereum/go-ethereum/event"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/metrics"
 	"github.com/ethereum/go-ethereum/p2p/enode"
@@ -84,6 +85,7 @@ type Table struct {
 	closeReq   chan struct{}
 	closed     chan struct{}
 
+	nodeFeed        event.FeedOf[*enode.Node]
 	nodeAddedHook   func(*bucket, *tableNode)
 	nodeRemovedHook func(*bucket, *tableNode)
 }
@@ -567,6 +569,8 @@ func (tab *Table) nodeAdded(b *bucket, n *tableNode) {
 	}
 	n.addedToBucket = time.Now()
 	tab.revalidation.nodeAdded(tab, n)
+
+	tab.nodeFeed.Send(n.Node)
 	if tab.nodeAddedHook != nil {
 		tab.nodeAddedHook(b, n)
 	}
@@ -702,3 +706,38 @@ func (tab *Table) deleteNode(n *enode.Node) {
 	b := tab.bucket(n.ID())
 	tab.deleteInBucket(b, n.ID())
 }
+
+// waitForNodes blocks until the table contains at least n nodes.
+func (tab *Table) waitForNodes(ctx context.Context, n int) error {
+	getlength := func() (count int) {
+		for _, b := range &tab.buckets {
+			count += len(b.entries)
+		}
+		return count
+	}
+
+	var ch chan *enode.Node
+	for {
+		tab.mutex.Lock()
+		if getlength() >= n {
+			tab.mutex.Unlock()
+			return nil
+		}
+		if ch == nil {
+			// Init subscription.
+			ch = make(chan *enode.Node)
+			sub := tab.nodeFeed.Subscribe(ch)
+			defer sub.Unsubscribe()
+		}
+		tab.mutex.Unlock()
+
+		// Wait for a node add event.
+		select {
+		case <-ch:
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-tab.closeReq:
+			return errClosed
+		}
+	}
+}
diff --git a/p2p/discover/v4_udp_test.go b/p2p/discover/v4_udp_test.go
index 1af31f4f1b9..44863183fa9 100644
--- a/p2p/discover/v4_udp_test.go
+++ b/p2p/discover/v4_udp_test.go
@@ -24,10 +24,12 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"maps"
 	"math/rand"
 	"net"
 	"net/netip"
 	"reflect"
+	"slices"
 	"sync"
 	"testing"
 	"time"
@@ -509,18 +511,26 @@ func TestUDPv4_smallNetConvergence(t *testing.T) {
 	// they have all found each other.
 	status := make(chan error, len(nodes))
 	for i := range nodes {
-		node := nodes[i]
+		self := nodes[i]
 		go func() {
-			found := make(map[enode.ID]bool, len(nodes))
-			it := node.RandomNodes()
+			missing := make(map[enode.ID]bool, len(nodes))
+			for _, n := range nodes {
+				if n.Self().ID() == self.Self().ID() {
+					continue // skip self
+				}
+				missing[n.Self().ID()] = true
+			}
+
+			it := self.RandomNodes()
 			for it.Next() {
-				found[it.Node().ID()] = true
-				if len(found) == len(nodes) {
+				delete(missing, it.Node().ID())
+				if len(missing) == 0 {
 					status <- nil
 					return
 				}
 			}
-			status <- fmt.Errorf("node %s didn't find all nodes", node.Self().ID().TerminalString())
+			missingIDs := slices.Collect(maps.Keys(missing))
+			status <- fmt.Errorf("node %s didn't find all nodes, missing %v", self.Self().ID().TerminalString(), missingIDs)
 		}()
 	}
 
@@ -537,7 +547,6 @@ func TestUDPv4_smallNetConvergence(t *testing.T) {
 			received++
 			if err != nil {
 				t.Error("ERROR:", err)
-				return
 			}
 		}
 	}
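
Reviewer note, not part of the patch: the new Table.waitForNodes is built on a subscribe-then-recheck loop over the new nodeFeed. The standalone Go sketch below shows that pattern in isolation; the registry type, its add/waitFor methods, and the string items are invented for illustration, and the only go-ethereum API it assumes is the event.FeedOf type already used by this diff.

// registry_wait_sketch.go -- illustrative only, not part of the go-ethereum patch.
package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/ethereum/go-ethereum/event"
)

// registry is a hypothetical stand-in for the discovery Table: add stores an
// item and notifies subscribers through an event feed, and waitFor blocks
// until at least n items are present or the context expires.
type registry struct {
	mu    sync.Mutex
	items []string
	feed  event.FeedOf[string] // zero value is ready to use, like Table.nodeFeed
}

func (r *registry) add(item string) {
	r.mu.Lock()
	r.items = append(r.items, item)
	r.mu.Unlock()
	r.feed.Send(item) // wake up waiters; mirrors tab.nodeFeed.Send(n.Node)
}

func (r *registry) count() int {
	r.mu.Lock()
	defer r.mu.Unlock()
	return len(r.items)
}

func (r *registry) waitFor(ctx context.Context, n int) error {
	// Subscribe before checking, so an add that happens between the check and
	// the wait still produces an event on ch.
	ch := make(chan string)
	sub := r.feed.Subscribe(ch)
	defer sub.Unsubscribe()

	for r.count() < n {
		select {
		case <-ch: // something was added; re-check the condition
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return nil
}

func main() {
	r := new(registry)
	go func() {
		for _, item := range []string{"a", "b", "c"} {
			time.Sleep(10 * time.Millisecond)
			r.add(item)
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	fmt.Println("waitFor(3):", r.waitFor(ctx, 3)) // prints <nil> once three items exist
}

As in the sketch, the real waitForNodes sets up the subscription before it can miss an update: the bucket-count check and the Subscribe call happen under tab.mutex, so a node added concurrently either shows up in the count or produces an event on the channel.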