Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement a clean shutdown of the probe method #16

Merged
merged 3 commits into from
Sep 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 42 additions & 33 deletions server.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package zeroconf

import (
"errors"
"fmt"
"log"
"math/rand"
Expand Down Expand Up @@ -75,8 +74,7 @@ func Register(instance, service, domain string, port int, text []string, ifaces
}

s.service = entry
go s.mainloop()
go s.probe()
s.start()

return s, nil
}
Expand Down Expand Up @@ -132,8 +130,7 @@ func RegisterProxy(instance, service, domain string, port int, host string, ips
}

s.service = entry
go s.mainloop()
go s.probe()
s.start()

return s, nil
}
Expand All @@ -151,7 +148,7 @@ type Server struct {

shouldShutdown chan struct{}
shutdownLock sync.Mutex
shutdownEnd sync.WaitGroup
refCount sync.WaitGroup
isShutdown bool
ttl uint32
}
Expand Down Expand Up @@ -182,19 +179,17 @@ func newServer(ifaces []net.Interface) (*Server, error) {
return s, nil
}

// Start listeners and waits for the shutdown signal from exit channel
func (s *Server) mainloop() {
func (s *Server) start() {
if s.ipv4conn != nil {
s.refCount.Add(1)
go s.recv4(s.ipv4conn)
}
if s.ipv6conn != nil {
s.refCount.Add(1)
go s.recv6(s.ipv6conn)
}
}

// Shutdown closes all udp connections and unregisters the service
func (s *Server) Shutdown() {
s.shutdown()
s.refCount.Add(1)
go s.probe()
}

// SetText updates and announces the TXT records
Expand All @@ -208,15 +203,17 @@ func (s *Server) TTL(ttl uint32) {
s.ttl = ttl
}

// Shutdown server will close currently open connections & channel
func (s *Server) shutdown() error {
// Shutdown closes all udp connections and unregisters the service
func (s *Server) Shutdown() {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()
if s.isShutdown {
return errors.New("server is already shutdown")
return
}

err := s.unregister()
if err := s.unregister(); err != nil {
log.Printf("failed to unregister: %s", err)
}

close(s.shouldShutdown)

Expand All @@ -228,20 +225,17 @@ func (s *Server) shutdown() error {
}

// Wait for connection and routines to be closed
s.shutdownEnd.Wait()
s.refCount.Wait()
s.isShutdown = true

return err
}

// recv is a long running routine to receive packets from an interface
// recv4 is a long running routine to receive packets from an interface
func (s *Server) recv4(c *ipv4.PacketConn) {
defer s.refCount.Done()
if c == nil {
return
}
buf := make([]byte, 65536)
s.shutdownEnd.Add(1)
defer s.shutdownEnd.Done()
for {
select {
case <-s.shouldShutdown:
Expand All @@ -260,14 +254,13 @@ func (s *Server) recv4(c *ipv4.PacketConn) {
}
}

// recv is a long running routine to receive packets from an interface
// recv6 is a long running routine to receive packets from an interface
func (s *Server) recv6(c *ipv6.PacketConn) {
defer s.refCount.Done()
if c == nil {
return
}
buf := make([]byte, 65536)
s.shutdownEnd.Add(1)
defer s.shutdownEnd.Done()
for {
select {
case <-s.shouldShutdown:
Expand Down Expand Up @@ -528,6 +521,8 @@ func (s *Server) serviceTypeName(resp *dns.Msg, ttl uint32) {
// Perform probing & announcement
//TODO: implement a proper probing & conflict resolution
func (s *Server) probe() {
defer s.refCount.Done()

q := new(dns.Msg)
q.SetQuestion(s.service.ServiceInstanceName(), dns.TypePTR)
q.RecursionDesired = false
Expand Down Expand Up @@ -555,16 +550,25 @@ func (s *Server) probe() {
}
q.Ns = []dns.RR{srv, txt}

randomizer := rand.New(rand.NewSource(time.Now().UnixNano()))

// Wait for a random duration uniformly distributed between 0 and 250 ms
// before sending the first probe packet.
time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond)
timer := time.NewTimer(time.Duration(rand.Intn(250)) * time.Millisecond)
defer timer.Stop()
select {
case <-timer.C:
case <-s.shouldShutdown:
return
}
for i := 0; i < 3; i++ {
if err := s.multicastResponse(q, 0); err != nil {
log.Println("[ERR] zeroconf: failed to send probe:", err.Error())
}
time.Sleep(250 * time.Millisecond)
timer.Reset(250 * time.Millisecond)
select {
case <-timer.C:
case <-s.shouldShutdown:
return
}
}

// From RFC6762
Expand All @@ -573,7 +577,7 @@ func (s *Server) probe() {
// packet loss, a responder MAY send up to eight unsolicited responses,
// provided that the interval between unsolicited responses increases by
// at least a factor of two with every response sent.
timeout := 1 * time.Second
timeout := time.Second
for i := 0; i < multicastRepetitions; i++ {
for _, intf := range s.ifaces {
resp := new(dns.Msg)
Expand All @@ -587,7 +591,12 @@ func (s *Server) probe() {
log.Println("[ERR] zeroconf: failed to send announcement:", err.Error())
}
}
time.Sleep(timeout)
timer.Reset(timeout)
select {
case <-timer.C:
case <-s.shouldShutdown:
return
}
timeout *= 2
}
}
Expand Down Expand Up @@ -719,7 +728,7 @@ func (s *Server) unicastResponse(resp *dns.Msg, ifIndex int, from net.Addr) erro
}
}

// multicastResponse us used to send a multicast response packet
// multicastResponse is used to send a multicast response packet
func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error {
buf, err := msg.Pack()
if err != nil {
Expand Down
18 changes: 18 additions & 0 deletions service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,24 @@ func startMDNS(t *testing.T, port int, name, service, domain string) {
log.Printf("Published service: %s, type: %s, domain: %s", name, service, domain)
}

func TestQuickShutdown(t *testing.T) {
server, err := Register(mdnsName, mdnsService, mdnsDomain, mdnsPort, []string{"txtv=0", "lo=1", "la=2"}, nil)
if err != nil {
t.Fatal(err)
}

done := make(chan struct{})
go func() {
defer close(done)
server.Shutdown()
}()
select {
case <-done:
case <-time.After(500 * time.Millisecond):
t.Fatal("shutdown took longer than 500ms")
}
}

func TestBasic(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down