diff --git a/caddyhttp/httpserver/graceful.go b/caddyhttp/httpserver/graceful.go deleted file mode 100644 index f11a6c9aaef..00000000000 --- a/caddyhttp/httpserver/graceful.go +++ /dev/null @@ -1,80 +0,0 @@ -package httpserver - -import ( - "net" - "sync" - "syscall" -) - -// TODO: Should this be a generic graceful listener available in its own package or something? -// Also, passing in a WaitGroup is a little awkward. Why can't this listener just keep -// the waitgroup internal to itself? - -// newGracefulListener returns a gracefulListener that wraps l and -// uses wg (stored in the host server) to count connections. -func newGracefulListener(l net.Listener, wg *sync.WaitGroup) *gracefulListener { - gl := &gracefulListener{Listener: l, stop: make(chan error), connWg: wg} - go func() { - <-gl.stop - gl.Lock() - gl.stopped = true - gl.Unlock() - gl.stop <- gl.Listener.Close() - }() - return gl -} - -// gracefuListener is a net.Listener which can -// count the number of connections on it. Its -// methods mainly wrap net.Listener to be graceful. -type gracefulListener struct { - net.Listener - stop chan error - stopped bool - sync.Mutex // protects the stopped flag - connWg *sync.WaitGroup // pointer to the host's wg used for counting connections -} - -// Accept accepts a connection. -func (gl *gracefulListener) Accept() (c net.Conn, err error) { - c, err = gl.Listener.Accept() - if err != nil { - return - } - c = gracefulConn{Conn: c, connWg: gl.connWg} - gl.connWg.Add(1) - return -} - -// Close immediately closes the listener. -func (gl *gracefulListener) Close() error { - gl.Lock() - if gl.stopped { - gl.Unlock() - return syscall.EINVAL - } - gl.Unlock() - gl.stop <- nil - return <-gl.stop -} - -// gracefulConn represents a connection on a -// gracefulListener so that we can keep track -// of the number of connections, thus facilitating -// a graceful shutdown. -type gracefulConn struct { - net.Conn - connWg *sync.WaitGroup // pointer to the host server's connection waitgroup -} - -// Close closes c's underlying connection while updating the wg count. -func (c gracefulConn) Close() error { - err := c.Conn.Close() - if err != nil { - return err - } - // close can fail on http2 connections (as of Oct. 2015, before http2 in std lib) - // so don't decrement count unless close succeeds - c.connWg.Done() - return nil -} diff --git a/caddyhttp/httpserver/server.go b/caddyhttp/httpserver/server.go index f42804d6b57..3f60843f139 100644 --- a/caddyhttp/httpserver/server.go +++ b/caddyhttp/httpserver/server.go @@ -3,6 +3,7 @@ package httpserver import ( "crypto/tls" + "errors" "fmt" "log" "net" @@ -25,11 +26,13 @@ type Server struct { Server *http.Server quicServer *h2quic.Server listener net.Listener - listenerMu sync.Mutex + mu sync.Mutex sites []*SiteConfig - connTimeout time.Duration // max time to wait for a connection before force stop - connWg sync.WaitGroup // one increment per connection - tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine + closed bool + connTimeout time.Duration // max time to wait for a connection before force stop + connWg *sync.WaitGroup // one increment per connection + conns map[net.Conn]http.ConnState // store idle connections + tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine vhosts *vhostTrie } @@ -50,18 +53,12 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) { vhosts: newVHostTrie(), sites: group, connTimeout: GracefulTimeout, + connWg: &sync.WaitGroup{}, } s.Server.Handler = s // this is weird, but whatever - s.Server.ConnState = func(c net.Conn, cs http.ConnState) { - if cs == http.StateIdle { - s.listenerMu.Lock() - // server stopped, close idle connection - if s.listener == nil { - c.Close() - } - s.listenerMu.Unlock() - } - } + + // Track connection state + s.connState() // Disable HTTP/2 if desired if !HTTP2 { @@ -74,14 +71,6 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) { s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler) } - // We have to bound our wg with one increment - // to prevent a "race condition" that is hard-coded - // into sync.WaitGroup.Wait() - basically, an add - // with a positive delta must be guaranteed to - // occur before Wait() is called on the wg. - // In a way, this kind of acts as a safety barrier. - s.connWg.Add(1) - // Set up TLS configuration var tlsConfigs []*caddytls.Config var err error @@ -159,11 +148,7 @@ func (s *Server) Serve(ln net.Listener) error { ln = tcpKeepAliveListener{TCPListener: tcpLn} } - ln = newGracefulListener(ln, &s.connWg) - - s.listenerMu.Lock() s.listener = ln - s.listenerMu.Unlock() if s.Server.TLSConfig != nil { // Create TLS listener - note that we do not replace s.listener @@ -292,40 +277,98 @@ func (s *Server) Address() string { // Stop stops s gracefully (or forcefully after timeout) and // closes its listener. -func (s *Server) Stop() (err error) { +func (s *Server) Stop() error { + if s.closed { + return errors.New("Server has been closed") + } + + // Make sure a listener was set + if s.listener != nil { + // Close the listener to stop all new connections + if err := s.listener.Close(); err != nil { + return err + } + } + s.Server.SetKeepAlivesEnabled(false) + s.mu.Lock() + s.closed = true + s.mu.Unlock() + + // Wait for any connections to finish + wait := make(chan struct{}) + go func() { + defer close(wait) + + // Closing this signals any TLS governor goroutines to exit + if s.tlsGovChan != nil { + close(s.tlsGovChan) + } - if runtime.GOOS != "windows" { - // force connections to close after timeout - done := make(chan struct{}) - go func() { - s.connWg.Done() // decrement our initial increment used as a barrier - s.connWg.Wait() - close(done) - }() + s.connWg.Wait() + }() - // Wait for remaining connections to finish or - // force them all to close after timeout - select { - case <-time.After(s.connTimeout): - case <-done: + // We block until all active connections are closed or the connTimeout happens + select { + case <-time.After(s.connTimeout): + s.mu.Lock() + for c, st := range s.conns { + // Force close any idle and new connections. + if st == http.StateIdle || st == http.StateNew { + c.Close() + } } + s.mu.Unlock() + return nil + case <-wait: + return nil } +} - // Close the listener now; this stops the server without delay - s.listenerMu.Lock() - if s.listener != nil { - err = s.listener.Close() - s.listener = nil - } - s.listenerMu.Unlock() +// connState setups the ConnState tracking hook to know which connections are idle +func (s *Server) connState() { + // Set our ConnState to track idle connections + s.Server.ConnState = func(c net.Conn, cs http.ConnState) { + s.mu.Lock() + defer s.mu.Unlock() + + switch cs { + case http.StateNew: + // New connections increment the WaitGroup and are added the the conns dictionary + s.connWg.Add(1) + if s.conns == nil { + s.conns = make(map[net.Conn]http.ConnState) + } + s.conns[c] = cs + case http.StateActive: + // Only update status to StateActive if it's in the conns dictionary + if _, ok := s.conns[c]; ok { + s.conns[c] = cs + } + case http.StateIdle: + // Only update status to StateIdle if it's in the conns dictionary + if _, ok := s.conns[c]; ok { + s.conns[c] = cs + } - // Closing this signals any TLS governor goroutines to exit - if s.tlsGovChan != nil { - close(s.tlsGovChan) + // If we've already closed then we need to close this connection. + // We don't allow connections to become idle after server is closed + if s.closed { + c.Close() + } + case http.StateHijacked, http.StateClosed: + // If the connection is hijacked or closed we forget it + s.forgetConn(c) + } } +} - return +// forgetConn removes c from conns and decrements WaitGroup +func (s *Server) forgetConn(c net.Conn) { + if _, ok := s.conns[c]; ok { + delete(s.conns, c) + s.connWg.Done() + } } // sanitizePath collapses any ./ ../ /// madness diff --git a/caddyhttp/httpserver/server_test.go b/caddyhttp/httpserver/server_test.go index d8e53c100e9..354575a3c25 100644 --- a/caddyhttp/httpserver/server_test.go +++ b/caddyhttp/httpserver/server_test.go @@ -1,8 +1,12 @@ package httpserver import ( + "fmt" + "io/ioutil" "net/http" + "net/http/httptest" "testing" + "time" ) func TestAddress(t *testing.T) { @@ -13,3 +17,82 @@ func TestAddress(t *testing.T) { t.Errorf("Expected '%s' but got '%s'", want, got) } } + +func TestStop(t *testing.T) { + // Create Server + s, err := NewServer("", nil) + if err != nil { + t.Fatal(err) + } + + if err := s.Stop(); err != nil { + t.Error("Server errored while trying to Stop", err) + } +} + +func TestServer(t *testing.T) { + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + defer ts.Close() + + // Create Server + s, err := NewServer("", nil) + if err != nil { + t.Fatal(err) + } + + // Reduce connTimeout for testing + s.connTimeout = 1 * time.Millisecond + + s.Server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "hello") + }) + + // Set the test server config to the Server + ts.Config = s.Server + ts.Start() + + // Set listener + s.listener = ts.Listener + + client := http.Client{} + res, err := client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + got, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + + if string(got) != "hello" { + t.Errorf("got %q, want hello", string(got)) + } + + // Make sure there is only 1 connection + s.mu.Lock() + if len(s.conns) < 1 { + t.Fatal("Should have 1 connections") + } + s.mu.Unlock() + + // Stop the server + s.Stop() + + // Try to connect to the server after it's closed + res, err = client.Get(ts.URL) + + // This should always error because new connections are not allowed + if err == nil { + t.Fatal("Should not accept new connections after close") + } + + // Make sure there are zero connections + s.mu.Lock() + if len(s.conns) < 0 { + t.Fatal("Should have 0 connections") + } + s.mu.Unlock() +}