diff --git a/discv5/crawler.go b/discv5/crawler.go index 04fb750..e719efd 100644 --- a/discv5/crawler.go +++ b/discv5/crawler.go @@ -504,26 +504,22 @@ func (c *Crawler) crawlDiscV5(ctx context.Context, pi PeerInfo) chan DiscV5Resul } result.DoneAt = time.Now() - result.Error = err + // if we have at least a successful result, don't record error + if noSuccessfulRequest(err, errorBits.Load()) { + result.Error = err + } result.RoutingTable = &core.RoutingTable[PeerInfo]{ PeerID: pi.ID(), Neighbors: []PeerInfo{}, ErrorBits: uint16(errorBits.Load()), - Error: err, + Error: result.Error, } for _, n := range allNeighbors { result.RoutingTable.Neighbors = append(result.RoutingTable.Neighbors, n) } - // if we have at least a successful result, delete error - // bitwise operation checks whether errorBits is a power of 2 minus 1, - // if not, then there was at least one successful result - if result.Error != nil && (result.RoutingTable.ErrorBits&(result.RoutingTable.ErrorBits+1)) == 0 { - result.Error = nil - } - // if there was a connection error, parse it to a known one if result.Error != nil { result.ErrorStr = db.NetError(result.Error) @@ -539,3 +535,16 @@ func (c *Crawler) crawlDiscV5(ctx context.Context, pi PeerInfo) chan DiscV5Resul return resultCh } + +// noSuccessfulRequest returns true if the given error is non nil, and all bits +// of the given errorBits are set. This means that no successful request has +// been made. This is equivalent to verifying that all righmost bits are equal +// to 1, or that the errorBits is a power of 2 minus 1. +// +// Examples: +// 0b00000011 -> true +// 0b00000111 -> true +// 0b00001101 -> false +func noSuccessfulRequest(err error, errorBits uint32) bool { + return err != nil && errorBits&(errorBits+1) == 0 +} diff --git a/discv5/crawler_test.go b/discv5/crawler_test.go index a3cd44d..b803ba4 100644 --- a/discv5/crawler_test.go +++ b/discv5/crawler_test.go @@ -1,10 +1,14 @@ package discv5 import ( + "errors" + "fmt" + "math" "testing" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/dennis-tra/nebula-crawler/nebtest" ) @@ -98,3 +102,90 @@ func Test_sanitizeAddrs(t *testing.T) { }) } } + +func TestNoSuccessfulRequest(t *testing.T) { + tests := []struct { + name string + err error + errorBits uint32 + want bool + }{ + { + name: "no err", + err: nil, + errorBits: 0b00000000, + want: false, + }, + { + name: "first failed", + err: fmt.Errorf("some err"), + errorBits: 0b00000001, + want: true, + }, + { + name: "second failed, first worked", + err: fmt.Errorf("some err"), + errorBits: 0b00000010, + want: false, + }, + { + name: "all four failed", + err: fmt.Errorf("some err"), + errorBits: 0b00001111, + want: true, + }, + { + name: "seven failed, one worked", + err: fmt.Errorf("some err"), + errorBits: 0b11110111, + want: false, + }, + { + name: "eight failed, but the last one overflowing succeeded (no error)", + err: nil, + errorBits: 0b11111111, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := noSuccessfulRequest(tt.err, tt.errorBits); got != tt.want { + t.Errorf("noSuccessfulRequest() = %v, want %v, %s", got, tt.want, tt.name) + } + }) + } + + // fail if err is nil + require.False(t, noSuccessfulRequest(nil, 0)) + require.False(t, noSuccessfulRequest(nil, 0b11111111)) + + err := errors.New("error") + + // list of numbers that are power of two minus one + // for which noSuccessfulRequest should return true + // because all bits are set (all failures = no success) + powerOfTwoMinusOneList := []uint32{ + 0b00000000, + 0b00000001, + 0b00000011, + 0b00000111, + 0b00001111, + 0b00011111, + 0b00111111, + 0b01111111, + 0b11111111, + } + + for i := uint32(0); i < uint32(math.Pow(2, 8)); i++ { + powerOfTwoMinusOne := false + for _, v := range powerOfTwoMinusOneList { + if i == v { + powerOfTwoMinusOne = true + break + } + } + // assert that noSuccessfulRequest returns true if and only if + // all bits are set + require.Equal(t, powerOfTwoMinusOne, noSuccessfulRequest(err, i)) + } +}