diff --git a/connlimit.go b/connlimit.go index d445df1..248d969 100644 --- a/connlimit.go +++ b/connlimit.go @@ -2,16 +2,23 @@ package connlimit import ( "errors" + "fmt" "net" "net/http" "sync" "sync/atomic" + "time" ) var ( // ErrPerClientIPLimitReached is returned if accepting a new conn would exceed // the per-client-ip limit set. ErrPerClientIPLimitReached = errors.New("client connection limit reached") + tooManyConnsMsg = "Your IP is issuing too many concurrent connections, please rate limit your calls\n" + tooManyRequestsResponse = []byte(fmt.Sprintf("HTTP/1.1 429 Too Many Requests\r\n"+ + "Content-Type: text/plain\r\n"+ + "Content-Length: %d\r\n"+ + "Connection: close\r\n\r\n%s", len(tooManyConnsMsg), tooManyConnsMsg)) ) // Limiter implements a simple limiter that tracks the number of connections @@ -173,7 +180,7 @@ func (l *Limiter) SetConfig(c Config) { l.cfg.Store(c) } -// HTTPConnStateFunc returns a func that can be passed as the ConnState field of +// HTTPConnStateFuncWithErrorHandler returns a func that can be passed as the ConnState field of // an http.Server. This intercepts new HTTP connections to the server and // applies the limiting to new connections. // @@ -181,13 +188,15 @@ func (l *Limiter) SetConfig(c Config) { // in the limiter as if it was closed. Servers that use Hijacking must implement // their own calls if they need to continue limiting the number of concurrent // hijacked connections. -func (l *Limiter) HTTPConnStateFunc() func(net.Conn, http.ConnState) { +// errorHandler MUST close the connection itself +func (l *Limiter) HTTPConnStateFuncWithErrorHandler(errorHandler func(error, net.Conn)) func(net.Conn, http.ConnState) { + return func(conn net.Conn, state http.ConnState) { switch state { case http.StateNew: _, err := l.Accept(conn) if err != nil { - conn.Close() + errorHandler(err, conn) } case http.StateHijacked: l.freeConn(conn) @@ -199,3 +208,26 @@ func (l *Limiter) HTTPConnStateFunc() func(net.Conn, http.ConnState) { } } } + +// HTTPConnStateFunc is here for ascending compatibility reasons. +func (l *Limiter) HTTPConnStateFunc() func(net.Conn, http.ConnState) { + return l.HTTPConnStateFuncWithErrorHandler(func(err error, conn net.Conn) { + conn.Close() + }) +} + +// HTTPConnStateFuncWithDefault429Handler return an HTTP 429 if too many connections occur. +// BEWARE that returning HTTP 429 is done on critical path, you might choose to use +// HTTPConnStateFuncWithErrorHandler if you want to use a non-blocking strategy. +func (l *Limiter) HTTPConnStateFuncWithDefault429Handler(writeDeadlineMaxDelay time.Duration) func(net.Conn, http.ConnState) { + return l.HTTPConnStateFuncWithErrorHandler(func(err error, conn net.Conn) { + if err == ErrPerClientIPLimitReached { + // We don't care about slow players + if writeDeadlineMaxDelay > 0 { + conn.SetDeadline(time.Now().Add(writeDeadlineMaxDelay)) + } + conn.Write(tooManyRequestsResponse) + } + conn.Close() + }) +} diff --git a/connlimit_test.go b/connlimit_test.go index 59c7360..7560820 100644 --- a/connlimit_test.go +++ b/connlimit_test.go @@ -281,13 +281,40 @@ func TestHTTPServer(t *testing.T) { lim := NewLimiter(Config{ MaxConnsPerClientIP: 5, }) + funToTest := func() func(net.Conn, http.ConnState) { + return lim.HTTPConnStateFunc() + } + testHTTPServerWithConnState(t, funToTest) +} + +func TestHTTPServerWith429WithoutDuration(t *testing.T) { + lim := NewLimiter(Config{ + MaxConnsPerClientIP: 5, + }) + funToTest := func() func(net.Conn, http.ConnState) { + return lim.HTTPConnStateFuncWithDefault429Handler(time.Duration(0)) + } + testHTTPServerWithConnState(t, funToTest) +} + +func TestHTTPServerWith429WithDuration(t *testing.T) { + lim := NewLimiter(Config{ + MaxConnsPerClientIP: 5, + }) + funToTest := func() func(net.Conn, http.ConnState) { + return lim.HTTPConnStateFuncWithDefault429Handler(time.Duration(time.Second)) + } + testHTTPServerWithConnState(t, funToTest) +} + +func testHTTPServerWithConnState(t *testing.T, funToTest func() func(net.Conn, http.ConnState)) { srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(1 * time.Second) w.Write([]byte("OK")) })) - srv.Config.ConnState = lim.HTTPConnStateFunc() + srv.Config.ConnState = funToTest() srv.Start() client := srv.Client() @@ -315,6 +342,9 @@ func TestHTTPServer(t *testing.T) { } atomic.AddUint64(&reset, 1) } else { + if resp.StatusCode == http.StatusTooManyRequests { + atomic.AddUint64(&reset, 1) + } resp.Body.Close() } }(i)