Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating server to use ConnState to track idle connections #1010

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 0 additions & 80 deletions caddyhttp/httpserver/graceful.go

This file was deleted.

145 changes: 94 additions & 51 deletions caddyhttp/httpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package httpserver

import (
"crypto/tls"
"errors"
"fmt"
"log"
"net"
Expand All @@ -25,11 +26,13 @@ type Server struct {
Server *http.Server
quicServer *h2quic.Server
listener net.Listener
listenerMu sync.Mutex
mu sync.Mutex
sites []*SiteConfig
connTimeout time.Duration // max time to wait for a connection before force stop
connWg sync.WaitGroup // one increment per connection
tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
closed bool
connTimeout time.Duration // max time to wait for a connection before force stop
connWg *sync.WaitGroup // one increment per connection
conns map[net.Conn]http.ConnState // store idle connections
tlsGovChan chan struct{} // close to stop the TLS maintenance goroutine
vhosts *vhostTrie
}

Expand All @@ -50,18 +53,12 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
vhosts: newVHostTrie(),
sites: group,
connTimeout: GracefulTimeout,
connWg: &sync.WaitGroup{},
}
s.Server.Handler = s // this is weird, but whatever
s.Server.ConnState = func(c net.Conn, cs http.ConnState) {
if cs == http.StateIdle {
s.listenerMu.Lock()
// server stopped, close idle connection
if s.listener == nil {
c.Close()
}
s.listenerMu.Unlock()
}
}

// Track connection state
s.connState()

// Disable HTTP/2 if desired
if !HTTP2 {
Expand All @@ -74,14 +71,6 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
s.Server.Handler = s.wrapWithSvcHeaders(s.Server.Handler)
}

// We have to bound our wg with one increment
// to prevent a "race condition" that is hard-coded
// into sync.WaitGroup.Wait() - basically, an add
// with a positive delta must be guaranteed to
// occur before Wait() is called on the wg.
// In a way, this kind of acts as a safety barrier.
s.connWg.Add(1)

// Set up TLS configuration
var tlsConfigs []*caddytls.Config
var err error
Expand Down Expand Up @@ -159,11 +148,7 @@ func (s *Server) Serve(ln net.Listener) error {
ln = tcpKeepAliveListener{TCPListener: tcpLn}
}

ln = newGracefulListener(ln, &s.connWg)

s.listenerMu.Lock()
s.listener = ln
s.listenerMu.Unlock()

if s.Server.TLSConfig != nil {
// Create TLS listener - note that we do not replace s.listener
Expand Down Expand Up @@ -292,40 +277,98 @@ func (s *Server) Address() string {

// Stop stops s gracefully (or forcefully after timeout) and
// closes its listener.
func (s *Server) Stop() (err error) {
func (s *Server) Stop() error {
if s.closed {
return errors.New("Server has been closed")
}

// Make sure a listener was set
if s.listener != nil {
// Close the listener to stop all new connections
if err := s.listener.Close(); err != nil {
return err
}
}

s.Server.SetKeepAlivesEnabled(false)
s.mu.Lock()
s.closed = true
s.mu.Unlock()

// Wait for any connections to finish
wait := make(chan struct{})
go func() {
defer close(wait)

// Closing this signals any TLS governor goroutines to exit
if s.tlsGovChan != nil {
close(s.tlsGovChan)
}

if runtime.GOOS != "windows" {
// force connections to close after timeout
done := make(chan struct{})
go func() {
s.connWg.Done() // decrement our initial increment used as a barrier
s.connWg.Wait()
close(done)
}()
s.connWg.Wait()
}()

// Wait for remaining connections to finish or
// force them all to close after timeout
select {
case <-time.After(s.connTimeout):
case <-done:
// We block until all active connections are closed or the connTimeout happens
select {
case <-time.After(s.connTimeout):
s.mu.Lock()
for c, st := range s.conns {
// Force close any idle and new connections.
if st == http.StateIdle || st == http.StateNew {
c.Close()
}
}
s.mu.Unlock()
return nil
case <-wait:
return nil
}
}

// Close the listener now; this stops the server without delay
s.listenerMu.Lock()
if s.listener != nil {
err = s.listener.Close()
s.listener = nil
}
s.listenerMu.Unlock()
// connState setups the ConnState tracking hook to know which connections are idle
func (s *Server) connState() {
// Set our ConnState to track idle connections
s.Server.ConnState = func(c net.Conn, cs http.ConnState) {
s.mu.Lock()
defer s.mu.Unlock()

switch cs {
case http.StateNew:
// New connections increment the WaitGroup and are added the the conns dictionary
s.connWg.Add(1)
if s.conns == nil {
s.conns = make(map[net.Conn]http.ConnState)
}
s.conns[c] = cs
case http.StateActive:
// Only update status to StateActive if it's in the conns dictionary
if _, ok := s.conns[c]; ok {
s.conns[c] = cs
}
case http.StateIdle:
// Only update status to StateIdle if it's in the conns dictionary
if _, ok := s.conns[c]; ok {
s.conns[c] = cs
}

// Closing this signals any TLS governor goroutines to exit
if s.tlsGovChan != nil {
close(s.tlsGovChan)
// If we've already closed then we need to close this connection.
// We don't allow connections to become idle after server is closed
if s.closed {
c.Close()
}
case http.StateHijacked, http.StateClosed:
// If the connection is hijacked or closed we forget it
s.forgetConn(c)
}
}
}

return
// forgetConn removes c from conns and decrements WaitGroup
func (s *Server) forgetConn(c net.Conn) {
if _, ok := s.conns[c]; ok {
delete(s.conns, c)
s.connWg.Done()
}
}

// sanitizePath collapses any ./ ../ /// madness
Expand Down
83 changes: 83 additions & 0 deletions caddyhttp/httpserver/server_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package httpserver

import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
)

func TestAddress(t *testing.T) {
Expand All @@ -13,3 +17,82 @@ func TestAddress(t *testing.T) {
t.Errorf("Expected '%s' but got '%s'", want, got)
}
}

func TestStop(t *testing.T) {
// Create Server
s, err := NewServer("", nil)
if err != nil {
t.Fatal(err)
}

if err := s.Stop(); err != nil {
t.Error("Server errored while trying to Stop", err)
}
}

func TestServer(t *testing.T) {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, client")
}))
defer ts.Close()

// Create Server
s, err := NewServer("", nil)
if err != nil {
t.Fatal(err)
}

// Reduce connTimeout for testing
s.connTimeout = 1 * time.Millisecond

s.Server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "hello")
})

// Set the test server config to the Server
ts.Config = s.Server
ts.Start()

// Set listener
s.listener = ts.Listener

client := http.Client{}
res, err := client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}

got, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}

if string(got) != "hello" {
t.Errorf("got %q, want hello", string(got))
}

// Make sure there is only 1 connection
s.mu.Lock()
if len(s.conns) < 1 {
t.Fatal("Should have 1 connections")
}
s.mu.Unlock()

// Stop the server
s.Stop()

// Try to connect to the server after it's closed
res, err = client.Get(ts.URL)

// This should always error because new connections are not allowed
if err == nil {
t.Fatal("Should not accept new connections after close")
}

// Make sure there are zero connections
s.mu.Lock()
if len(s.conns) < 0 {
t.Fatal("Should have 0 connections")
}
s.mu.Unlock()
}