Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement HTTP 429 when too many requests per IP calls limit is reached #6

Merged
merged 9 commits into from
Jun 22, 2020
38 changes: 35 additions & 3 deletions connlimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@ package connlimit

import (
"errors"
"fmt"
"net"
"net/http"
"sync"
"sync/atomic"
"time"
)

var (
// ErrPerClientIPLimitReached is returned if accepting a new conn would exceed
// the per-client-ip limit set.
ErrPerClientIPLimitReached = errors.New("client connection limit reached")
tooManyConnsMsg = "Your IP is issuing too many concurrent connections, please rate limit your calls\n"
tooManyRequestsResponse = []byte(fmt.Sprintf("HTTP/1.1 429 Too Many Requests\r\n"+
"Content-Type: text/plain\r\n"+
"Content-Length: %d\r\n"+
"Connection: close\r\n\r\n%s", len(tooManyConnsMsg), tooManyConnsMsg))
)

// Limiter implements a simple limiter that tracks the number of connections
Expand Down Expand Up @@ -173,21 +180,23 @@ func (l *Limiter) SetConfig(c Config) {
l.cfg.Store(c)
}

// HTTPConnStateFunc returns a func that can be passed as the ConnState field of
// HTTPConnStateFuncWithErrorHandler returns a func that can be passed as the ConnState field of
// an http.Server. This intercepts new HTTP connections to the server and
// applies the limiting to new connections.
//
// Note that if the conn is hijacked from the HTTP server then it will be freed
// in the limiter as if it was closed. Servers that use Hijacking must implement
// their own calls if they need to continue limiting the number of concurrent
// hijacked connections.
func (l *Limiter) HTTPConnStateFunc() func(net.Conn, http.ConnState) {
// errorHandler MUST close the connection itself
func (l *Limiter) HTTPConnStateFuncWithErrorHandler(errorHandler func(error, net.Conn)) func(net.Conn, http.ConnState) {

return func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateNew:
_, err := l.Accept(conn)
if err != nil {
conn.Close()
errorHandler(err, conn)
}
case http.StateHijacked:
l.freeConn(conn)
Expand All @@ -199,3 +208,26 @@ func (l *Limiter) HTTPConnStateFunc() func(net.Conn, http.ConnState) {
}
}
}

// HTTPConnStateFunc is here for ascending compatibility reasons.
func (l *Limiter) HTTPConnStateFunc() func(net.Conn, http.ConnState) {
return l.HTTPConnStateFuncWithErrorHandler(func(err error, conn net.Conn) {
conn.Close()
})
}

// HTTPConnStateFuncWithDefault429Handler return an HTTP 429 if too many connections occur.
// BEWARE that returning HTTP 429 is done on critical path, you might choose to use
// HTTPConnStateFuncWithErrorHandler if you want to use a non-blocking strategy.
func (l *Limiter) HTTPConnStateFuncWithDefault429Handler(writeDeadlineMaxDelay time.Duration) func(net.Conn, http.ConnState) {
return l.HTTPConnStateFuncWithErrorHandler(func(err error, conn net.Conn) {
if err == ErrPerClientIPLimitReached {
// We don't care about slow players
if writeDeadlineMaxDelay > 0 {
conn.SetDeadline(time.Now().Add(writeDeadlineMaxDelay))
}
conn.Write(tooManyRequestsResponse)
}
conn.Close()
})
}
32 changes: 31 additions & 1 deletion connlimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,13 +281,40 @@ func TestHTTPServer(t *testing.T) {
lim := NewLimiter(Config{
MaxConnsPerClientIP: 5,
})
funToTest := func() func(net.Conn, http.ConnState) {
return lim.HTTPConnStateFunc()
}
testHTTPServerWithConnState(t, funToTest)
}

func TestHTTPServerWith429WithoutDuration(t *testing.T) {
lim := NewLimiter(Config{
MaxConnsPerClientIP: 5,
})
funToTest := func() func(net.Conn, http.ConnState) {
return lim.HTTPConnStateFuncWithDefault429Handler(time.Duration(0))
}
testHTTPServerWithConnState(t, funToTest)
}

func TestHTTPServerWith429WithDuration(t *testing.T) {
lim := NewLimiter(Config{
MaxConnsPerClientIP: 5,
})
funToTest := func() func(net.Conn, http.ConnState) {
return lim.HTTPConnStateFuncWithDefault429Handler(time.Duration(time.Second))
}
testHTTPServerWithConnState(t, funToTest)
}

func testHTTPServerWithConnState(t *testing.T, funToTest func() func(net.Conn, http.ConnState)) {

srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(1 * time.Second)
w.Write([]byte("OK"))
}))

srv.Config.ConnState = lim.HTTPConnStateFunc()
srv.Config.ConnState = funToTest()
srv.Start()

client := srv.Client()
Expand Down Expand Up @@ -315,6 +342,9 @@ func TestHTTPServer(t *testing.T) {
}
atomic.AddUint64(&reset, 1)
} else {
if resp.StatusCode == http.StatusTooManyRequests {
atomic.AddUint64(&reset, 1)
}
resp.Body.Close()
}
}(i)
Expand Down