diff --git a/server.go b/server.go index 303383f2..ab1b3486 100644 --- a/server.go +++ b/server.go @@ -1,7 +1,6 @@ package zeroconf import ( - "errors" "fmt" "log" "math/rand" @@ -189,14 +188,10 @@ func (s *Server) start() { s.refCount.Add(1) go s.recv6(s.ipv6conn) } + s.refCount.Add(1) go s.probe() } -// Shutdown closes all udp connections and unregisters the service -func (s *Server) Shutdown() { - s.shutdown() -} - // SetText updates and announces the TXT records func (s *Server) SetText(text []string) { s.service.Text = text @@ -208,15 +203,17 @@ func (s *Server) TTL(ttl uint32) { s.ttl = ttl } -// Shutdown server will close currently open connections & channel -func (s *Server) shutdown() error { +// Shutdown closes all udp connections and unregisters the service +func (s *Server) Shutdown() { s.shutdownLock.Lock() defer s.shutdownLock.Unlock() if s.isShutdown { - return errors.New("server is already shutdown") + return } - err := s.unregister() + if err := s.unregister(); err != nil { + log.Printf("failed to unregister: %s", err) + } close(s.shouldShutdown) @@ -230,8 +227,6 @@ func (s *Server) shutdown() error { // Wait for connection and routines to be closed s.refCount.Wait() s.isShutdown = true - - return err } // recv4 is a long running routine to receive packets from an interface @@ -526,6 +521,8 @@ func (s *Server) serviceTypeName(resp *dns.Msg, ttl uint32) { // Perform probing & announcement //TODO: implement a proper probing & conflict resolution func (s *Server) probe() { + defer s.refCount.Done() + q := new(dns.Msg) q.SetQuestion(s.service.ServiceInstanceName(), dns.TypePTR) q.RecursionDesired = false @@ -555,12 +552,23 @@ func (s *Server) probe() { // Wait for a random duration uniformly distributed between 0 and 250 ms // before sending the first probe packet. - time.Sleep(time.Duration(rand.Intn(250)) * time.Millisecond) + timer := time.NewTimer(time.Duration(rand.Intn(250)) * time.Millisecond) + defer timer.Stop() + select { + case <-timer.C: + case <-s.shouldShutdown: + return + } for i := 0; i < 3; i++ { if err := s.multicastResponse(q, 0); err != nil { log.Println("[ERR] zeroconf: failed to send probe:", err.Error()) } - time.Sleep(250 * time.Millisecond) + timer.Reset(250 * time.Millisecond) + select { + case <-timer.C: + case <-s.shouldShutdown: + return + } } // From RFC6762 @@ -569,7 +577,7 @@ func (s *Server) probe() { // packet loss, a responder MAY send up to eight unsolicited responses, // provided that the interval between unsolicited responses increases by // at least a factor of two with every response sent. - timeout := 1 * time.Second + timeout := time.Second for i := 0; i < multicastRepetitions; i++ { for _, intf := range s.ifaces { resp := new(dns.Msg) @@ -583,7 +591,12 @@ func (s *Server) probe() { log.Println("[ERR] zeroconf: failed to send announcement:", err.Error()) } } - time.Sleep(timeout) + timer.Reset(timeout) + select { + case <-timer.C: + case <-s.shouldShutdown: + return + } timeout *= 2 } } @@ -715,7 +728,7 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro } } -// multicastResponse us used to send a multicast response packet +// multicastResponse is used to send a multicast response packet func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { buf, err := msg.Pack() if err != nil { diff --git a/service_test.go b/service_test.go index 35afe8d5..14f56761 100644 --- a/service_test.go +++ b/service_test.go @@ -25,6 +25,24 @@ func startMDNS(t *testing.T, port int, name, service, domain string) { log.Printf("Published service: %s, type: %s, domain: %s", name, service, domain) } +func TestQuickShutdown(t *testing.T) { + server, err := Register(mdnsName, mdnsService, mdnsDomain, mdnsPort, []string{"txtv=0", "lo=1", "la=2"}, nil) + if err != nil { + t.Fatal(err) + } + + done := make(chan struct{}) + go func() { + defer close(done) + server.Shutdown() + }() + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("shutdown took longer than 500ms") + } +} + func TestBasic(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel()