Skip to content

Commit

Permalink
Implement GetRejectedConnectionsCount function (#1704)
Browse files Browse the repository at this point in the history
* Implement `GetRejectedConnectionsCount`

* Implement test for `GetRejectedConnectionsCount`
  • Loading branch information
mopeneko authored Feb 10, 2024
1 parent dfb7e62 commit b430b88
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
10 changes: 10 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ type Server struct {
open int32
stop int32
done chan struct{}

rejectedRequestsCount uint32
}

// TimeoutHandler creates RequestHandler, which returns StatusRequestTimeout
Expand Down Expand Up @@ -1828,6 +1830,7 @@ func (s *Server) Serve(ln net.Listener) error {
atomic.AddInt32(&s.open, 1)
if !wp.Serve(c) {
atomic.AddInt32(&s.open, -1)
atomic.AddUint32(&s.rejectedRequestsCount, 1)
s.writeFastError(c, StatusServiceUnavailable,
"The connection cannot be served because Server.Concurrency limit exceeded")
c.Close()
Expand Down Expand Up @@ -2073,6 +2076,13 @@ func (s *Server) GetOpenConnectionsCount() int32 {
return atomic.LoadInt32(&s.open)
}

// GetRejectedConnectionsCount returns a number of rejected connections.
//
// This function is intended be used by monitoring systems.
func (s *Server) GetRejectedConnectionsCount() uint32 {
return atomic.LoadUint32(&s.rejectedRequestsCount)
}

func (s *Server) getConcurrency() int {
n := s.Concurrency
if n <= 0 {
Expand Down
56 changes: 56 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,62 @@ func TestServerConcurrencyLimit(t *testing.T) {
}
}

func TestRejectedRequestsCount(t *testing.T) {
t.Parallel()

s := &Server{
Handler: func(ctx *RequestCtx) {
ctx.WriteString("OK") //nolint:errcheck
},
Concurrency: 1,
Logger: &testLogger{},
}

ln := fasthttputil.NewInmemoryListener()

serverCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
t.Errorf("unexpected error: %v", err)
}
close(serverCh)
}()

clientCh := make(chan struct{})
expectedCount := 5
go func() {
for i := 0; i < expectedCount+1; i++ {
_, err := ln.Dial()
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}

if cnt := s.GetRejectedConnectionsCount(); cnt != uint32(expectedCount) {
t.Errorf("unexpected rejected connections count: %d. Expecting %d",
cnt, expectedCount)
}

close(clientCh)
}()

select {
case <-clientCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}

if err := ln.Close(); err != nil {
t.Fatalf("unexpected error: %v", err)
}

select {
case <-serverCh:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}

func TestServerWriteFastError(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit b430b88

Please sign in to comment.