p2p/discover: fix update logic in handleAddNode #29836

Merged
12 commits merged on May 28, 2024
46 changes: 33 additions & 13 deletions p2p/discover/table.go
@@ -513,8 +513,9 @@ func (tab *Table) handleAddNode(req addNodeOp) bool {
}

b := tab.bucket(req.node.ID())
if tab.bumpInBucket(b, req.node.Node) {
// Already in bucket, update record.
n, _ := tab.bumpInBucket(b, req.node.Node, req.isInbound)
if n != nil {
// Already in bucket.
return false
}
if len(b.entries) >= bucketSize {
@@ -605,26 +606,45 @@ func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *node {
return rep
}

// bumpInBucket updates the node record of n in the bucket.
func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node) bool {
// bumpInBucket updates a node record if it exists in the bucket.
// The second return value reports whether the node's endpoint (IP/port) was updated.
func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool) (n *node, endpointChanged bool) {
i := slices.IndexFunc(b.entries, func(elem *node) bool {
return elem.ID() == newRecord.ID()
})
if i == -1 {
return false
return nil, false // not in bucket
}
n = b.entries[i]

// For inbound updates (from the node itself) we accept any change, even if it sets
// back the sequence number. For found nodes (!isInbound), seq has to advance. Note
// this check also ensures found discv4 nodes (which always have seq=0) can't be
// updated.
if newRecord.Seq() <= n.Seq() && !isInbound {
return n, false
}

if !newRecord.IP().Equal(b.entries[i].IP()) {
// Endpoint has changed, ensure that the new IP fits into table limits.
tab.removeIP(b, b.entries[i].IP())
// Check endpoint update against IP limits.
ipchanged := newRecord.IPAddr() != n.IPAddr()
portchanged := newRecord.UDP() != n.UDP()
if ipchanged {
tab.removeIP(b, n.IP())
if !tab.addIP(b, newRecord.IP()) {
// It doesn't, put the previous one back.
tab.addIP(b, b.entries[i].IP())
return false
// It doesn't fit with the limit, put the previous record back.
tab.addIP(b, n.IP())
return n, false
}
}
b.entries[i].Node = newRecord
return true

// Apply update.
n.Node = newRecord
if ipchanged || portchanged {
// Ensure node is revalidated quickly for endpoint changes.
tab.revalidation.nodeEndpointChanged(tab, n)
return n, true
}
return n, false
}

func (tab *Table) handleTrackRequest(op trackRequestOp) {
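Note (illustrative, not part of the diff): the acceptance rule in the new bumpInBucket boils down to "inbound updates are always taken, found records must advance the sequence number". A minimal standalone sketch, with a hypothetical function name and simplified signature:

package main

import "fmt"

// shouldUpdate mirrors the rule in bumpInBucket: updates received from the node
// itself (inbound) are always accepted, even if the sequence number goes
// backwards, while records learned from other nodes must advance the sequence
// number. Found discv4 records always carry seq=0, so they never replace an
// existing entry.
func shouldUpdate(oldSeq, newSeq uint64, isInbound bool) bool {
	if newSeq <= oldSeq && !isInbound {
		return false
	}
	return true
}

func main() {
	fmt.Println(shouldUpdate(5, 6, false)) // true: found record with higher seq
	fmt.Println(shouldUpdate(5, 5, false)) // false: found record, seq did not advance
	fmt.Println(shouldUpdate(5, 3, true))  // true: inbound update may lower seq
	fmt.Println(shouldUpdate(0, 0, false)) // false: found discv4 record (seq=0)
}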
25 changes: 13 additions & 12 deletions p2p/discover/table_reval.go
@@ -28,6 +28,8 @@ import (

const never = mclock.AbsTime(math.MaxInt64)

const slowRevalidationFactor = 3

// tableRevalidation implements the node revalidation process.
// It tracks all nodes contained in Table, and schedules sending PING to them.
type tableRevalidation struct {
@@ -48,7 +50,7 @@ func (tr *tableRevalidation) init(cfg *Config) {
tr.fast.interval = cfg.PingInterval
tr.fast.name = "fast"
tr.slow.nextTime = never
tr.slow.interval = cfg.PingInterval * 3
tr.slow.interval = cfg.PingInterval * slowRevalidationFactor
tr.slow.name = "slow"
}

@@ -65,6 +67,12 @@ func (tr *tableRevalidation) nodeRemoved(n *node) {
n.revalList.remove(n)
}

// nodeEndpointChanged is called when a change in IP or port is detected.
func (tr *tableRevalidation) nodeEndpointChanged(tab *Table, n *node) {
n.isValidatedLive = false
tr.moveToList(&tr.fast, n, tab.cfg.Clock.Now(), &tab.rand)
}

// run performs node revalidation.
// It returns the next time it should be invoked, which is used in the Table main loop
// to schedule a timer. However, run can be called at any time.
@@ -146,11 +154,11 @@ func (tr *tableRevalidation) handleResponse(tab *Table, resp revalidationResponse) {
defer tab.mutex.Unlock()

if !resp.didRespond {
// Revalidation failed.
n.livenessChecks /= 3
if n.livenessChecks <= 0 {
tab.deleteInBucket(b, n.ID())
} else {
tab.log.Debug("Node revalidation failed", "b", b.index, "id", n.ID(), "checks", n.livenessChecks, "q", n.revalList.name)
tr.moveToList(&tr.fast, n, now, &tab.rand)
}
return
@@ -159,22 +167,15 @@
// The node responded.
n.livenessChecks++
n.isValidatedLive = true
tab.log.Debug("Node revalidated", "b", b.index, "id", n.ID(), "checks", n.livenessChecks, "q", n.revalList.name)
var endpointChanged bool
if resp.newRecord != nil {
endpointChanged = tab.bumpInBucket(b, resp.newRecord)
if endpointChanged {
// If the node changed its advertised endpoint, the updated ENR is not served
// until it has been revalidated.
n.isValidatedLive = false
}
_, endpointChanged = tab.bumpInBucket(b, resp.newRecord, false)
}
tab.log.Debug("Revalidated node", "b", b.index, "id", n.ID(), "checks", n.livenessChecks, "q", n.revalList)

// Move node over to slow queue after first validation.
// Node moves to slow list if it passed and hasn't changed.
if !endpointChanged {
tr.moveToList(&tr.slow, n, now, &tab.rand)
} else {
tr.moveToList(&tr.fast, n, now, &tab.rand)
}
}

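Note (illustrative, not part of the diff): the routing performed at the end of handleResponse can be condensed into one hypothetical helper; names and the int liveness score below are illustrative only.

// routeAfterRevalidation condenses handleResponse: a failed ping cuts the
// liveness score and eventually evicts the node; a successful ping promotes it
// to the slow queue unless its endpoint just changed, in which case it stays on
// the fast queue until the new endpoint has been verified again.
func routeAfterRevalidation(didRespond, endpointChanged bool, livenessChecks int) (queue string, evict bool) {
	if !didRespond {
		livenessChecks /= 3
		if livenessChecks <= 0 {
			return "", true // drop from bucket
		}
		return "fast", false // retry soon
	}
	if endpointChanged {
		return "fast", false // re-verify the new endpoint quickly
	}
	return "slow", false
}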
53 changes: 51 additions & 2 deletions p2p/discover/table_reval_test.go
@@ -22,11 +22,13 @@ import (
"time"

"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/enr"
)

// This test checks that revalidation can handle a node disappearing while
// a request is active.
func TestRevalidationNodeRemoved(t *testing.T) {
func TestRevalidation_nodeRemoved(t *testing.T) {
var (
clock mclock.Simulated
transport = newPingRecorder()
@@ -35,7 +37,7 @@ func TestRevalidationNodeRemoved(t *testing.T) {
)
defer db.Close()

// Fill a bucket.
// Add a node to the table.
node := nodeAtDistance(tab.self().ID(), 255, net.IP{77, 88, 99, 1})
tab.handleAddNode(addNodeOp{node: node})

@@ -68,3 +70,50 @@
t.Fatal("removed node contained in revalidation list")
}
}

// This test checks that nodes with an updated endpoint remain in the fast revalidation list.
func TestRevalidation_endpointUpdate(t *testing.T) {
var (
clock mclock.Simulated
transport = newPingRecorder()
tab, db = newInactiveTestTable(transport, Config{Clock: &clock})
tr = &tab.revalidation
)
defer db.Close()

// Add node to table.
node := nodeAtDistance(tab.self().ID(), 255, net.IP{77, 88, 99, 1})
tab.handleAddNode(addNodeOp{node: node})

// Update the record in transport, including endpoint update.
record := node.Record()
record.Set(enr.IP{100, 100, 100, 100})
record.Set(enr.UDP(9999))
nodev2 := enode.SignNull(record, node.ID())
transport.updateRecord(nodev2)

// Start a revalidation request. Schedule once to get the next start time,
// then advance the clock to that point and schedule again to start.
next := tr.run(tab, clock.Now())
clock.Run(time.Duration(next + 1))
tr.run(tab, clock.Now())
if len(tr.activeReq) != 1 {
t.Fatal("revalidation request did not start:", tr.activeReq)
}

// Now finish the revalidation request.
var resp revalidationResponse
select {
case resp = <-tab.revalResponseCh:
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for revalidation")
}
tr.handleResponse(tab, resp)

if !tr.fast.contains(node.ID()) {
t.Fatal("node not contained in fast revalidation list")
}
if node.isValidatedLive {
t.Fatal("node is marked live after endpoint change")
}
}
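Note (illustrative, not part of the diff): both revalidation tests drive the loop with the same simulated-clock pattern, using the types and imports already present in this file. A sketch with a hypothetical helper name:

// driveRevalidationOnce names the pattern used in the tests above: run once to
// learn when the next revalidation is due, advance the simulated clock past
// that point, then run again so a request actually starts.
func driveRevalidationOnce(tr *tableRevalidation, tab *Table, clock *mclock.Simulated) {
	next := tr.run(tab, clock.Now())
	clock.Run(time.Duration(next + 1))
	tr.run(tab, clock.Now())
}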
122 changes: 90 additions & 32 deletions p2p/discover/table_test.go
@@ -131,7 +131,7 @@ func waitForRevalidationPing(t *testing.T, transport *pingRecorder, tab *Table,
simclock := tab.cfg.Clock.(*mclock.Simulated)
maxAttempts := tab.len() * 8
for i := 0; i < maxAttempts; i++ {
simclock.Run(tab.cfg.PingInterval)
simclock.Run(tab.cfg.PingInterval * slowRevalidationFactor)
p := transport.waitPing(2 * time.Second)
if p == nil {
t.Fatal("Table did not send revalidation ping")
@@ -275,7 +275,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
return reflect.ValueOf(t)
}

func TestTable_addVerifiedNode(t *testing.T) {
func TestTable_addInboundNode(t *testing.T) {
tab, db := newTestTable(newPingRecorder(), Config{})
<-tab.initDone
defer db.Close()
@@ -286,29 +286,26 @@
n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
tab.addFoundNode(n1)
tab.addFoundNode(n2)
bucket := tab.bucket(n1.ID())
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2.Node})

// Verify bucket content:
bcontent := []*node{n1, n2}
if !reflect.DeepEqual(unwrapNodes(bucket.entries), unwrapNodes(bcontent)) {
t.Fatalf("wrong bucket content: %v", bucket.entries)
}

// Add a changed version of n2.
// Add a changed version of n2. The bucket should be updated.
newrec := n2.Record()
newrec.Set(enr.IP{99, 99, 99, 99})
newn2 := wrapNode(enode.SignNull(newrec, n2.ID()))
tab.addInboundNode(newn2)

// Check that bucket is updated correctly.
newBcontent := []*node{n1, newn2}
if !reflect.DeepEqual(unwrapNodes(bucket.entries), unwrapNodes(newBcontent)) {
t.Fatalf("wrong bucket content after update: %v", bucket.entries)
}
checkIPLimitInvariant(t, tab)
n2v2 := enode.SignNull(newrec, n2.ID())
tab.addInboundNode(wrapNode(n2v2))
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2})

// Try updating n2 without sequence number change. The update is accepted
// because it's inbound.
newrec = n2.Record()
newrec.Set(enr.IP{100, 100, 100, 100})
newrec.SetSeq(n2.Seq())
n2v3 := enode.SignNull(newrec, n2.ID())
tab.addInboundNode(wrapNode(n2v3))
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v3})
}

func TestTable_addSeenNode(t *testing.T) {
func TestTable_addFoundNode(t *testing.T) {
tab, db := newTestTable(newPingRecorder(), Config{})
<-tab.initDone
defer db.Close()
@@ -319,23 +316,84 @@
n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2})
tab.addFoundNode(n1)
tab.addFoundNode(n2)
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2.Node})

// Verify bucket content:
bcontent := []*node{n1, n2}
if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) {
t.Fatalf("wrong bucket content: %v", tab.bucket(n1.ID()).entries)
}

// Add a changed version of n2.
// Add a changed version of n2. The bucket should be updated.
newrec := n2.Record()
newrec.Set(enr.IP{99, 99, 99, 99})
newn2 := wrapNode(enode.SignNull(newrec, n2.ID()))
tab.addFoundNode(newn2)
n2v2 := enode.SignNull(newrec, n2.ID())
tab.addFoundNode(wrapNode(n2v2))
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2})

// Try updating n2 without a sequence number change.
// The update should not be accepted.
newrec = n2.Record()
newrec.Set(enr.IP{100, 100, 100, 100})
newrec.SetSeq(n2.Seq())
n2v3 := enode.SignNull(newrec, n2.ID())
tab.addFoundNode(wrapNode(n2v3))
checkBucketContent(t, tab, []*enode.Node{n1.Node, n2v2})
}

// Check that bucket content is unchanged.
if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) {
t.Fatalf("wrong bucket content after update: %v", tab.bucket(n1.ID()).entries)
// This test checks that discv4 nodes can update their own endpoint via PING.
func TestTable_addInboundNodeUpdateV4Accept(t *testing.T) {
tab, db := newTestTable(newPingRecorder(), Config{})
<-tab.initDone
defer db.Close()
defer tab.close()

// Add a v4 node.
key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3")
n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000)
tab.addInboundNode(wrapNode(n1))
checkBucketContent(t, tab, []*enode.Node{n1})

// Add an updated version with changed IP.
// The update will be accepted because it is inbound.
n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000)
tab.addInboundNode(wrapNode(n1v2))
checkBucketContent(t, tab, []*enode.Node{n1v2})
}

// This test checks that discv4 node entries will NOT be updated when a
// changed record is found.
func TestTable_addFoundNodeV4UpdateReject(t *testing.T) {
tab, db := newTestTable(newPingRecorder(), Config{})
<-tab.initDone
defer db.Close()
defer tab.close()

// Add a v4 node.
key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3")
n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000)
tab.addFoundNode(wrapNode(n1))
checkBucketContent(t, tab, []*enode.Node{n1})

// Add an updated version with changed IP.
// The update won't be accepted because it isn't inbound.
n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000)
tab.addFoundNode(wrapNode(n1v2))
checkBucketContent(t, tab, []*enode.Node{n1})
}

func checkBucketContent(t *testing.T, tab *Table, nodes []*enode.Node) {
t.Helper()

b := tab.bucket(nodes[0].ID())
if reflect.DeepEqual(unwrapNodes(b.entries), nodes) {
return
}
t.Log("wrong bucket content. have nodes:")
for _, n := range b.entries {
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP())
}
t.Log("want nodes:")
for _, n := range nodes {
t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP())
}
t.FailNow()

// Also check IP limits.
checkIPLimitInvariant(t, tab)
}
