diff --git a/limiter.go b/limiter.go index 3e20976b..5cd8cfa2 100644 --- a/limiter.go +++ b/limiter.go @@ -42,10 +42,9 @@ func (dj *dialJob) dialTimeout() time.Duration { type dialLimiter struct { lk sync.Mutex - isFdConsumingFnc isFdConsumingFnc - fdConsuming int - fdLimit int - waitingOnFd []*dialJob + fdConsuming int + fdLimit int + waitingOnFd []*dialJob dialFunc dialfunc @@ -55,21 +54,19 @@ type dialLimiter struct { } type dialfunc func(context.Context, peer.ID, ma.Multiaddr) (transport.CapableConn, error) -type isFdConsumingFnc func(ma.Multiaddr) bool -func newDialLimiter(df dialfunc, fdFnc isFdConsumingFnc) *dialLimiter { +func newDialLimiter(df dialfunc) *dialLimiter { fd := ConcurrentFdDials if env := os.Getenv("LIBP2P_SWARM_FD_LIMIT"); env != "" { if n, err := strconv.ParseInt(env, 10, 32); err == nil { fd = int(n) } } - return newDialLimiterWithParams(fdFnc, df, fd, DefaultPerPeerRateLimit) + return newDialLimiterWithParams(df, fd, DefaultPerPeerRateLimit) } -func newDialLimiterWithParams(isFdConsumingFnc isFdConsumingFnc, df dialfunc, fdLimit, perPeerLimit int) *dialLimiter { +func newDialLimiterWithParams(df dialfunc, fdLimit, perPeerLimit int) *dialLimiter { return &dialLimiter{ - isFdConsumingFnc: isFdConsumingFnc, fdLimit: fdLimit, perPeerLimit: perPeerLimit, waitingOnPeerLimit: make(map[peer.ID][]*dialJob), @@ -157,7 +154,7 @@ func (dl *dialLimiter) shouldConsumeFd(addr ma.Multiaddr) bool { isRelay := err == nil - return !isRelay && dl.isFdConsumingFnc(addr) + return !isRelay && isFdConsumingAddr(addr) } func (dl *dialLimiter) addCheckFdLimit(dj *dialJob) { diff --git a/limiter_test.go b/limiter_test.go index 1aefffec..4a592b86 100644 --- a/limiter_test.go +++ b/limiter_test.go @@ -18,19 +18,6 @@ import ( mafmt "github.com/multiformats/go-multiaddr-fmt" ) -var isFdConsuming = func(addr ma.Multiaddr) bool { - res := false - - ma.ForEach(addr, func(c ma.Component) bool { - if c.Protocol().Code == ma.P_TCP { - res = true - return false - } - return true - }) - return res -} - func mustAddr(t *testing.T, s string) ma.Multiaddr { a, err := ma.NewMultiaddr(s) if err != nil { @@ -95,7 +82,7 @@ func TestLimiterBasicDials(t *testing.T) { hang := make(chan struct{}) defer close(hang) - l := newDialLimiterWithParams(isFdConsuming, hangDialFunc(hang), ConcurrentFdDials, 4) + l := newDialLimiterWithParams(hangDialFunc(hang), ConcurrentFdDials, 4) bads := []ma.Multiaddr{addrWithPort(t, 1), addrWithPort(t, 2), addrWithPort(t, 3), addrWithPort(t, 4)} good := addrWithPort(t, 20) @@ -144,7 +131,7 @@ func TestLimiterBasicDials(t *testing.T) { func TestFDLimiting(t *testing.T) { hang := make(chan struct{}) defer close(hang) - l := newDialLimiterWithParams(isFdConsuming, hangDialFunc(hang), 16, 5) + l := newDialLimiterWithParams(hangDialFunc(hang), 16, 5) bads := []ma.Multiaddr{addrWithPort(t, 1), addrWithPort(t, 2), addrWithPort(t, 3), addrWithPort(t, 4)} pids := []peer.ID{"testpeer1", "testpeer2", "testpeer3", "testpeer4"} @@ -220,7 +207,7 @@ func TestTokenRedistribution(t *testing.T) { <-ch return nil, fmt.Errorf("test bad dial") } - l := newDialLimiterWithParams(isFdConsuming, df, 8, 4) + l := newDialLimiterWithParams(df, 8, 4) bads := []ma.Multiaddr{addrWithPort(t, 1), addrWithPort(t, 2), addrWithPort(t, 3), addrWithPort(t, 4)} pids := []peer.ID{"testpeer1", "testpeer2"} @@ -313,7 +300,7 @@ func TestStressLimiter(t *testing.T) { return nil, fmt.Errorf("test bad dial") } - l := newDialLimiterWithParams(isFdConsuming, df, 20, 5) + l := newDialLimiterWithParams(df, 20, 5) var bads []ma.Multiaddr for i := 0; i < 100; i++ { @@ -367,7 +354,7 @@ func TestFDLimitUnderflow(t *testing.T) { } const fdLimit = 20 - l := newDialLimiterWithParams(isFdConsuming, df, fdLimit, 3) + l := newDialLimiterWithParams(df, fdLimit, 3) var addrs []ma.Multiaddr for i := 0; i <= 1000; i++ { diff --git a/swarm.go b/swarm.go index e975548f..35eb4156 100644 --- a/swarm.go +++ b/swarm.go @@ -123,7 +123,7 @@ func NewSwarm(ctx context.Context, local peer.ID, peers peerstore.Peerstore, bwc } s.dsync = newDialSync(s.startDialWorker) - s.limiter = newDialLimiter(s.dialAddr, isFdConsumingAddr) + s.limiter = newDialLimiter(s.dialAddr) s.proc = goprocessctx.WithContext(ctx) s.ctx = goprocessctx.OnClosingContext(s.proc) s.backf.init(s.ctx)