diff --git a/p2p/net/nat/mapping.go b/p2p/net/nat/mapping.go index 1835ed9a77..f9b508e4e2 100644 --- a/p2p/net/nat/mapping.go +++ b/p2p/net/nat/mapping.go @@ -5,8 +5,6 @@ import ( "net" "sync" "time" - - "github.com/jbenet/goprocess" ) // Mapping represents a port mapping in a NAT. @@ -42,7 +40,6 @@ type mapping struct { proto string intport int extport int - proc goprocess.Process cached net.IP cacheTime time.Time @@ -117,5 +114,6 @@ func (m *mapping) ExternalAddr() (net.Addr, error) { } func (m *mapping) Close() error { - return m.proc.Close() + m.nat.removeMapping(m) + return nil } diff --git a/p2p/net/nat/nat.go b/p2p/net/nat/nat.go index a29b5a9c1f..dad3226b46 100644 --- a/p2p/net/nat/nat.go +++ b/p2p/net/nat/nat.go @@ -8,16 +8,13 @@ import ( "time" logging "github.com/ipfs/go-log/v2" - goprocess "github.com/jbenet/goprocess" - periodic "github.com/jbenet/goprocess/periodic" - nat "github.com/libp2p/go-nat" -) -var ( - // ErrNoMapping signals no mapping exists for an address - ErrNoMapping = errors.New("mapping not established") + "github.com/libp2p/go-nat" ) +// ErrNoMapping signals no mapping exists for an address +var ErrNoMapping = errors.New("mapping not established") + var log = logging.Logger("nat") // MappingDuration is a default port mapping duration. @@ -30,24 +27,7 @@ const CacheTime = time.Second * 15 // DiscoverNAT looks for a NAT device in the network and // returns an object that can manage port mappings. func DiscoverNAT(ctx context.Context) (*NAT, error) { - var ( - natInstance nat.NAT - err error - ) - - done := make(chan struct{}) - go func() { - defer close(done) - // This will abort in 10 seconds anyways. - natInstance, err = nat.DiscoverGateway() - }() - - select { - case <-done: - case <-ctx.Done(): - return nil, ctx.Err() - } - + natInstance, err := nat.DiscoverGateway(ctx) if err != nil { return nil, err } @@ -70,29 +50,35 @@ func DiscoverNAT(ctx context.Context) (*NAT, error) { type NAT struct { natmu sync.Mutex nat nat.NAT - proc goprocess.Process + + refCount sync.WaitGroup + ctx context.Context + ctxCancel context.CancelFunc mappingmu sync.RWMutex // guards mappings + closed bool mappings map[*mapping]struct{} } func newNAT(realNAT nat.NAT) *NAT { + ctx, cancel := context.WithCancel(context.Background()) return &NAT{ - nat: realNAT, - proc: goprocess.WithParent(goprocess.Background()), - mappings: make(map[*mapping]struct{}), + nat: realNAT, + mappings: make(map[*mapping]struct{}), + ctx: ctx, + ctxCancel: cancel, } } // Close shuts down all port mappings. NAT can no longer be used. func (nat *NAT) Close() error { - return nat.proc.Close() -} + nat.mappingmu.Lock() + nat.closed = true + nat.mappingmu.Unlock() -// Process returns the nat's life-cycle manager, for making it listen -// to close signals. -func (nat *NAT) Process() goprocess.Process { - return nat.proc + nat.ctxCancel() + nat.refCount.Wait() + return nil } // Mappings returns a slice of all NAT mappings @@ -106,21 +92,6 @@ func (nat *NAT) Mappings() []Mapping { return maps2 } -func (nat *NAT) addMapping(m *mapping) { - // make mapping automatically close when nat is closed. - nat.proc.AddChild(m.proc) - - nat.mappingmu.Lock() - nat.mappings[m] = struct{}{} - nat.mappingmu.Unlock() -} - -func (nat *NAT) rmMapping(m *mapping) { - nat.mappingmu.Lock() - delete(nat.mappings, m) - nat.mappingmu.Unlock() -} - // NewMapping attempts to construct a mapping on protocol and internal port // It will also periodically renew the mapping until the returned Mapping // -- or its parent NAT -- is Closed. @@ -146,19 +117,15 @@ func (nat *NAT) NewMapping(protocol string, port int) (Mapping, error) { proto: protocol, } - m.proc = goprocess.WithTeardown(func() error { - nat.rmMapping(m) - nat.natmu.Lock() - defer nat.natmu.Unlock() - nat.nat.DeletePortMapping(m.Protocol(), m.InternalPort()) - return nil - }) - - nat.addMapping(m) - - m.proc.AddChild(periodic.Every(MappingDuration/3, func(worker goprocess.Process) { - nat.establishMapping(m) - })) + nat.mappingmu.Lock() + if nat.closed { + nat.mappingmu.Unlock() + return nil, errors.New("closed") + } + nat.mappings[m] = struct{}{} + nat.refCount.Add(1) + nat.mappingmu.Unlock() + go nat.refreshMappings(m) // do it once synchronously, so first mapping is done right away, and before exiting, // allowing users -- in the optimistic case -- to use results right after. @@ -166,11 +133,36 @@ func (nat *NAT) NewMapping(protocol string, port int) (Mapping, error) { return m, nil } +func (nat *NAT) removeMapping(m *mapping) { + nat.mappingmu.Lock() + delete(nat.mappings, m) + nat.mappingmu.Unlock() + nat.natmu.Lock() + nat.nat.DeletePortMapping(m.Protocol(), m.InternalPort()) + nat.natmu.Unlock() +} + +func (nat *NAT) refreshMappings(m *mapping) { + defer nat.refCount.Done() + t := time.NewTicker(MappingDuration / 3) + defer t.Stop() + + for { + select { + case <-t.C: + nat.establishMapping(m) + case <-nat.ctx.Done(): + m.Close() + return + } + } +} + func (nat *NAT) establishMapping(m *mapping) { oldport := m.ExternalPort() log.Debugf("Attempting port map: %s/%d", m.Protocol(), m.InternalPort()) - comment := "libp2p" + const comment = "libp2p" nat.natmu.Lock() newport, err := nat.nat.AddPortMapping(m.Protocol(), m.InternalPort(), comment, MappingDuration)