Skip to content

Commit

Permalink
Merge pull request #16 from libp2p/shutdown
Browse files Browse the repository at this point in the history
implement a clean shutdown of the probe method
  • Loading branch information
marten-seemann committed Sep 27, 2021
2 parents a3c4995 + f91e31c commit f9db021
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 33 deletions.
75 changes: 42 additions & 33 deletions server.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package zeroconf

import (
"errors"
"fmt"
"log"
"math/rand"
Expand Down Expand Up @@ -75,8 +74,7 @@ func Register(instance, service, domain string, port int, text []string, ifaces
}

s.service = entry
go s.mainloop()
go s.probe()
s.start()

return s, nil
}
Expand Down Expand Up @@ -132,8 +130,7 @@ func RegisterProxy(instance, service, domain string, port int, host string, ips
}

s.service = entry
go s.mainloop()
go s.probe()
s.start()

return s, nil
}
Expand All @@ -151,7 +148,7 @@ type Server struct {

shouldShutdown chan struct{}
shutdownLock sync.Mutex
shutdownEnd sync.WaitGroup
refCount sync.WaitGroup
isShutdown bool
ttl uint32
}
Expand Down Expand Up @@ -182,19 +179,17 @@ func newServer(ifaces []net.Interface) (*Server, error) {
return s, nil
}

// Start listeners and waits for the shutdown signal from exit channel
func (s *Server) mainloop() {
func (s *Server) start() {
if s.ipv4conn != nil {
s.refCount.Add(1)
go s.recv4(s.ipv4conn)
}
if s.ipv6conn != nil {
s.refCount.Add(1)
go s.recv6(s.ipv6conn)
}
}

// Shutdown closes all udp connections and unregisters the service
func (s *Server) Shutdown() {
s.shutdown()
s.refCount.Add(1)
go s.probe()
}

// SetText updates and announces the TXT records
Expand All @@ -208,15 +203,17 @@ func (s *Server) TTL(ttl uint32) {
s.ttl = ttl
}

// Shutdown server will close currently open connections & channel
func (s *Server) shutdown() error {
// Shutdown closes all udp connections and unregisters the service
func (s *Server) Shutdown() {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()
if s.isShutdown {
return errors.New("server is already shutdown")
return
}

err := s.unregister()
if err := s.unregister(); err != nil {
log.Printf("failed to unregister: %s", err)
}

close(s.shouldShutdown)

Expand All @@ -228,20 +225,17 @@ func (s *Server) shutdown() error {
}

// Wait for connection and routines to be closed
s.shutdownEnd.Wait()
s.refCount.Wait()
s.isShutdown = true

return err
}

// recv is a long running routine to receive packets from an interface
// recv4 is a long running routine to receive packets from an interface
func (s *Server) recv4(c *ipv4.PacketConn) {
defer s.refCount.Done()
if c == nil {
return
}
buf := make([]byte, 65536)
s.shutdownEnd.Add(1)
defer s.shutdownEnd.Done()
for {
select {
case <-s.shouldShutdown:
Expand All @@ -260,14 +254,13 @@ func (s *Server) recv4(c *ipv4.PacketConn) {
}
}

// recv is a long running routine to receive packets from an interface
// recv6 is a long running routine to receive packets from an interface
func (s *Server) recv6(c *ipv6.PacketConn) {
defer s.refCount.Done()
if c == nil {
return
}
buf := make([]byte, 65536)
s.shutdownEnd.Add(1)
defer s.shutdownEnd.Done()
for {
select {
case <-s.shouldShutdown:
Expand Down Expand Up @@ -528,6 +521,8 @@ func (s *Server) serviceTypeName(resp *dns.Msg, ttl uint32) {
// Perform probing & announcement
//TODO: implement a proper probing & conflict resolution
func (s *Server) probe() {
defer s.refCount.Done()

q := new(dns.Msg)
q.SetQuestion(s.service.ServiceInstanceName(), dns.TypePTR)
q.RecursionDesired = false
Expand Down Expand Up @@ -555,16 +550,25 @@ func (s *Server) probe() {
}
q.Ns = []dns.RR{srv, txt}

randomizer := rand.New(rand.NewSource(time.Now().UnixNano()))

// Wait for a random duration uniformly distributed between 0 and 250 ms
// before sending the first probe packet.
time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond)
timer := time.NewTimer(time.Duration(rand.Intn(250)) * time.Millisecond)
defer timer.Stop()
select {
case <-timer.C:
case <-s.shouldShutdown:
return
}
for i := 0; i < 3; i++ {
if err := s.multicastResponse(q, 0); err != nil {
log.Println("[ERR] zeroconf: failed to send probe:", err.Error())
}
time.Sleep(250 * time.Millisecond)
timer.Reset(250 * time.Millisecond)
select {
case <-timer.C:
case <-s.shouldShutdown:
return
}
}

// From RFC6762
Expand All @@ -573,7 +577,7 @@ func (s *Server) probe() {
// packet loss, a responder MAY send up to eight unsolicited responses,
// provided that the interval between unsolicited responses increases by
// at least a factor of two with every response sent.
timeout := 1 * time.Second
timeout := time.Second
for i := 0; i < multicastRepetitions; i++ {
for _, intf := range s.ifaces {
resp := new(dns.Msg)
Expand All @@ -587,7 +591,12 @@ func (s *Server) probe() {
log.Println("[ERR] zeroconf: failed to send announcement:", err.Error())
}
}
time.Sleep(timeout)
timer.Reset(timeout)
select {
case <-timer.C:
case <-s.shouldShutdown:
return
}
timeout *= 2
}
}
Expand Down Expand Up @@ -719,7 +728,7 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro
}
}

// multicastResponse us used to send a multicast response packet
// multicastResponse is used to send a multicast response packet
func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error {
buf, err := msg.Pack()
if err != nil {
Expand Down
18 changes: 18 additions & 0 deletions service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,24 @@ func startMDNS(t *testing.T, port int, name, service, domain string) {
log.Printf("Published service: %s, type: %s, domain: %s", name, service, domain)
}

func TestQuickShutdown(t *testing.T) {
server, err := Register(mdnsName, mdnsService, mdnsDomain, mdnsPort, []string{"txtv=0", "lo=1", "la=2"}, nil)
if err != nil {
t.Fatal(err)
}

done := make(chan struct{})
go func() {
defer close(done)
server.Shutdown()
}()
select {
case <-done:
case <-time.After(500 * time.Millisecond):
t.Fatal("shutdown took longer than 500ms")
}
}

func TestBasic(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down

0 comments on commit f9db021

Please sign in to comment.