From 73a22d526c8143c21e134be4574cd2275495fa90 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 20 Sep 2021 08:55:34 +0100 Subject: [PATCH 1/3] fix incrementing of sync.WaitGroup Incrementing needs to be done before the corresponding go routine is started. --- server.go | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/server.go b/server.go index 26c96ca1..f985e097 100644 --- a/server.go +++ b/server.go @@ -75,8 +75,7 @@ func Register(instance, service, domain string, port int, text []string, ifaces } s.service = entry - go s.mainloop() - go s.probe() + s.start() return s, nil } @@ -132,8 +131,7 @@ func RegisterProxy(instance, service, domain string, port int, host string, ips } s.service = entry - go s.mainloop() - go s.probe() + s.start() return s, nil } @@ -151,7 +149,7 @@ type Server struct { shouldShutdown chan struct{} shutdownLock sync.Mutex - shutdownEnd sync.WaitGroup + refCount sync.WaitGroup isShutdown bool ttl uint32 } @@ -182,14 +180,16 @@ func newServer(ifaces []net.Interface) (*Server, error) { return s, nil } -// Start listeners and waits for the shutdown signal from exit channel -func (s *Server) mainloop() { +func (s *Server) start() { if s.ipv4conn != nil { + s.refCount.Add(1) go s.recv4(s.ipv4conn) } if s.ipv6conn != nil { + s.refCount.Add(1) go s.recv6(s.ipv6conn) } + go s.probe() } // Shutdown closes all udp connections and unregisters the service @@ -228,20 +228,19 @@ func (s *Server) shutdown() error { } // Wait for connection and routines to be closed - s.shutdownEnd.Wait() + s.refCount.Wait() s.isShutdown = true return err } -// recv is a long running routine to receive packets from an interface +// recv4 is a long running routine to receive packets from an interface func (s *Server) recv4(c *ipv4.PacketConn) { + defer s.refCount.Done() if c == nil { return } buf := make([]byte, 65536) - s.shutdownEnd.Add(1) - defer s.shutdownEnd.Done() for { select { case <-s.shouldShutdown: @@ -260,14 +259,13 @@ func (s *Server) recv4(c *ipv4.PacketConn) { } } -// recv is a long running routine to receive packets from an interface +// recv6 is a long running routine to receive packets from an interface func (s *Server) recv6(c *ipv6.PacketConn) { + defer s.refCount.Done() if c == nil { return } buf := make([]byte, 65536) - s.shutdownEnd.Add(1) - defer s.shutdownEnd.Done() for { select { case <-s.shouldShutdown: From bed14e1c2f0d79ec2e4283c46d16f1711af60574 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 20 Sep 2021 08:58:52 +0100 Subject: [PATCH 2/3] simplify random number generation --- server.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/server.go b/server.go index f985e097..303383f2 100644 --- a/server.go +++ b/server.go @@ -553,11 +553,9 @@ func (s *Server) probe() { } q.Ns = []dns.RR{srv, txt} - randomizer := rand.New(rand.NewSource(time.Now().UnixNano())) - // Wait for a random duration uniformly distributed between 0 and 250 ms // before sending the first probe packet. - time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond) + time.Sleep(time.Duration(rand.Intn(250)) * time.Millisecond) for i := 0; i < 3; i++ { if err := s.multicastResponse(q, 0); err != nil { log.Println("[ERR] zeroconf: failed to send probe:", err.Error()) From f91e31cf0a9f8c0e313357d6f0b52cd8b5d03cb2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 20 Sep 2021 09:33:43 +0100 Subject: [PATCH 3/3] use timers instead of sleeps when probing --- server.go | 47 ++++++++++++++++++++++++++++++----------------- service_test.go | 18 ++++++++++++++++++ 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/server.go b/server.go index 303383f2..ab1b3486 100644 --- a/server.go +++ b/server.go @@ -1,7 +1,6 @@ package zeroconf import ( - "errors" "fmt" "log" "math/rand" @@ -189,14 +188,10 @@ func (s *Server) start() { s.refCount.Add(1) go s.recv6(s.ipv6conn) } + s.refCount.Add(1) go s.probe() } -// Shutdown closes all udp connections and unregisters the service -func (s *Server) Shutdown() { - s.shutdown() -} - // SetText updates and announces the TXT records func (s *Server) SetText(text []string) { s.service.Text = text @@ -208,15 +203,17 @@ func (s *Server) TTL(ttl uint32) { s.ttl = ttl } -// Shutdown server will close currently open connections & channel -func (s *Server) shutdown() error { +// Shutdown closes all udp connections and unregisters the service +func (s *Server) Shutdown() { s.shutdownLock.Lock() defer s.shutdownLock.Unlock() if s.isShutdown { - return errors.New("server is already shutdown") + return } - err := s.unregister() + if err := s.unregister(); err != nil { + log.Printf("failed to unregister: %s", err) + } close(s.shouldShutdown) @@ -230,8 +227,6 @@ func (s *Server) shutdown() error { // Wait for connection and routines to be closed s.refCount.Wait() s.isShutdown = true - - return err } // recv4 is a long running routine to receive packets from an interface @@ -526,6 +521,8 @@ func (s *Server) serviceTypeName(resp *dns.Msg, ttl uint32) { // Perform probing & announcement //TODO: implement a proper probing & conflict resolution func (s *Server) probe() { + defer s.refCount.Done() + q := new(dns.Msg) q.SetQuestion(s.service.ServiceInstanceName(), dns.TypePTR) q.RecursionDesired = false @@ -555,12 +552,23 @@ func (s *Server) probe() { // Wait for a random duration uniformly distributed between 0 and 250 ms // before sending the first probe packet. - time.Sleep(time.Duration(rand.Intn(250)) * time.Millisecond) + timer := time.NewTimer(time.Duration(rand.Intn(250)) * time.Millisecond) + defer timer.Stop() + select { + case <-timer.C: + case <-s.shouldShutdown: + return + } for i := 0; i < 3; i++ { if err := s.multicastResponse(q, 0); err != nil { log.Println("[ERR] zeroconf: failed to send probe:", err.Error()) } - time.Sleep(250 * time.Millisecond) + timer.Reset(250 * time.Millisecond) + select { + case <-timer.C: + case <-s.shouldShutdown: + return + } } // From RFC6762 @@ -569,7 +577,7 @@ func (s *Server) probe() { // packet loss, a responder MAY send up to eight unsolicited responses, // provided that the interval between unsolicited responses increases by // at least a factor of two with every response sent. - timeout := 1 * time.Second + timeout := time.Second for i := 0; i < multicastRepetitions; i++ { for _, intf := range s.ifaces { resp := new(dns.Msg) @@ -583,7 +591,12 @@ func (s *Server) probe() { log.Println("[ERR] zeroconf: failed to send announcement:", err.Error()) } } - time.Sleep(timeout) + timer.Reset(timeout) + select { + case <-timer.C: + case <-s.shouldShutdown: + return + } timeout *= 2 } } @@ -715,7 +728,7 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro } } -// multicastResponse us used to send a multicast response packet +// multicastResponse is used to send a multicast response packet func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { buf, err := msg.Pack() if err != nil { diff --git a/service_test.go b/service_test.go index 35afe8d5..14f56761 100644 --- a/service_test.go +++ b/service_test.go @@ -25,6 +25,24 @@ func startMDNS(t *testing.T, port int, name, service, domain string) { log.Printf("Published service: %s, type: %s, domain: %s", name, service, domain) } +func TestQuickShutdown(t *testing.T) { + server, err := Register(mdnsName, mdnsService, mdnsDomain, mdnsPort, []string{"txtv=0", "lo=1", "la=2"}, nil) + if err != nil { + t.Fatal(err) + } + + done := make(chan struct{}) + go func() { + defer close(done) + server.Shutdown() + }() + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("shutdown took longer than 500ms") + } +} + func TestBasic(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel()