server: prohibit more than MaxConcurrentStreams handlers from running at once (#6703) (#6708)
dfawley authored Oct 10, 2023
1 parent bd1f038 commit 5efd7bd
Showing 4 changed files with 180 additions and 45 deletions.
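For orientation: a minimal sketch (not part of the commit) of the server option whose limit this change enforces. The listen address and the cap of 100 are placeholder values.

package main

import (
	"log"
	"net"

	"google.golang.org/grpc"
)

func main() {
	lis, err := net.Listen("tcp", ":50051") // placeholder address
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	// Before this commit the server only advertised the limit in its
	// SETTINGS frame and trusted clients to honor it; after it, the server
	// also refuses to run more than this many handler goroutines at once.
	srv := grpc.NewServer(grpc.MaxConcurrentStreams(100))
	// ... register services here ...
	if err := srv.Serve(lis); err != nil {
		log.Fatalf("serve: %v", err)
	}
}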
11 changes: 3 additions & 8 deletions internal/transport/http2_server.go
@@ -171,15 +171,10 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
 		ID:  http2.SettingMaxFrameSize,
 		Val: http2MaxFrameLen,
 	}}
-	// TODO(zhaoq): Have a better way to signal "no limit" because 0 is
-	// permitted in the HTTP2 spec.
-	maxStreams := config.MaxStreams
-	if maxStreams == 0 {
-		maxStreams = math.MaxUint32
-	} else {
+	if config.MaxStreams != math.MaxUint32 {
 		isettings = append(isettings, http2.Setting{
 			ID:  http2.SettingMaxConcurrentStreams,
-			Val: maxStreams,
+			Val: config.MaxStreams,
 		})
 	}
 	dynamicWindow := true
@@ -258,7 +253,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
 		framer:            framer,
 		readerDone:        make(chan struct{}),
 		writerDone:        make(chan struct{}),
-		maxStreams:        maxStreams,
+		maxStreams:        config.MaxStreams,
 		inTapHandle:       config.InTapHandle,
 		fc:                &trInFlow{limit: uint32(icwz)},
 		state:             reachable,
35 changes: 19 additions & 16 deletions internal/transport/transport_test.go
@@ -336,6 +336,9 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
 	if err != nil {
 		return
 	}
+	if serverConfig.MaxStreams == 0 {
+		serverConfig.MaxStreams = math.MaxUint32
+	}
 	transport, err := NewServerTransport(conn, serverConfig)
 	if err != nil {
 		return
@@ -442,8 +445,8 @@ func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server
 	return server
 }
 
-func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) {
-	return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{})
+func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) {
+	return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{})
 }
 
 func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
@@ -538,7 +541,7 @@ func (s) TestInflightStreamClosing(t *testing.T) {
 
 // Tests that when streamID > MaxStreamId, the current client transport drains.
 func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer server.stop()
 	callHdr := &CallHdr{
@@ -583,7 +586,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
 }
 
 func (s) TestClientSendAndReceive(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -623,7 +626,7 @@ func (s) TestClientSendAndReceive(t *testing.T) {
 }
 
 func (s) TestClientErrorNotify(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	go server.stop()
 	// ct.reader should detect the error and activate ct.Error().
@@ -657,7 +660,7 @@ func performOneRPC(ct ClientTransport) {
 }
 
 func (s) TestClientMix(t *testing.T) {
-	s, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	s, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	time.AfterFunc(time.Second, s.stop)
 	go func(ct ClientTransport) {
@@ -671,7 +674,7 @@ func (s) TestClientMix(t *testing.T) {
 }
 
 func (s) TestLargeMessage(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -806,7 +809,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
 // proceed until they complete naturally, while not allowing creation of new
 // streams during this window.
 func (s) TestGracefulClose(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong)
+	server, ct, cancel := setUp(t, 0, pingpong)
 	defer cancel()
 	defer func() {
 		// Stop the server's listener to make the server's goroutines terminate
@@ -872,7 +875,7 @@ func (s) TestGracefulClose(t *testing.T) {
 }
 
 func (s) TestLargeMessageSuspension(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
+	server, ct, cancel := setUp(t, 0, suspended)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -980,7 +983,7 @@ func (s) TestMaxStreams(t *testing.T) {
 }
 
 func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
+	server, ct, cancel := setUp(t, 0, suspended)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -1452,7 +1455,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
 var encodingTestStatus = status.New(codes.Internal, "\n")
 
 func (s) TestEncodingRequiredStatus(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
+	server, ct, cancel := setUp(t, 0, encodingRequiredStatus)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -1480,7 +1483,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
 }
 
 func (s) TestInvalidHeaderField(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
+	server, ct, cancel := setUp(t, 0, invalidHeaderField)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -1502,7 +1505,7 @@ func (s) TestInvalidHeaderField(t *testing.T) {
 }
 
 func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
+	server, ct, cancel := setUp(t, 0, invalidHeaderField)
 	defer cancel()
 	defer server.stop()
 	defer ct.Close(fmt.Errorf("closed manually by test"))
@@ -2170,7 +2173,7 @@ func (s) TestPingPong1MB(t *testing.T) {
 
 // This is a stress-test of flow control logic.
 func runPingPongTest(t *testing.T, msgSize int) {
-	server, client, cancel := setUp(t, 0, 0, pingpong)
+	server, client, cancel := setUp(t, 0, pingpong)
 	defer cancel()
 	defer server.stop()
 	defer client.Close(fmt.Errorf("closed manually by test"))
@@ -2252,7 +2255,7 @@ func (s) TestHeaderTblSize(t *testing.T) {
 		}
 	}()
 
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer ct.Close(fmt.Errorf("closed manually by test"))
 	defer server.stop()
@@ -2611,7 +2614,7 @@ func TestConnectionError_Unwrap(t *testing.T) {
 
 func (s) TestPeerSetInServerContext(t *testing.T) {
 	// create client and server transports.
-	server, client, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, client, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer server.stop()
 	defer client.Close(fmt.Errorf("closed manually by test"))
69 changes: 48 additions & 21 deletions server.go
@@ -115,12 +115,6 @@ type serviceInfo struct {
 	mdata interface{}
 }
 
-type serverWorkerData struct {
-	st     transport.ServerTransport
-	wg     *sync.WaitGroup
-	stream *transport.Stream
-}
-
 // Server is a gRPC server to serve RPC requests.
 type Server struct {
 	opts serverOptions
Expand All @@ -145,7 +139,7 @@ type Server struct {
channelzID *channelz.Identifier
czData *channelzData

serverWorkerChannel chan *serverWorkerData
serverWorkerChannel chan func()
}

type serverOptions struct {
@@ -177,6 +171,7 @@ type serverOptions struct {
 }
 
 var defaultServerOptions = serverOptions{
+	maxConcurrentStreams:  math.MaxUint32,
 	maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
 	maxSendMessageSize:    defaultServerMaxSendMessageSize,
 	connectionTimeout:     120 * time.Second,
@@ -387,6 +382,9 @@ func MaxSendMsgSize(m int) ServerOption {
 // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
 // of concurrent streams to each ServerTransport.
 func MaxConcurrentStreams(n uint32) ServerOption {
+	if n == 0 {
+		n = math.MaxUint32
+	}
 	return newFuncServerOption(func(o *serverOptions) {
 		o.maxConcurrentStreams = n
 	})
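This replaces the old normalization that lived in http2_server.go (the TODO removed in the first hunk): mapping 0 to "no limit" now happens once, when the option is constructed. A brief illustration, values for illustration only:

// 0 is permitted by the HTTP/2 spec as a real limit, so it cannot be used
// on the wire to mean "no limit"; the option now normalizes it up front.
_ = grpc.MaxConcurrentStreams(0)   // stored internally as math.MaxUint32
_ = grpc.MaxConcurrentStreams(100) // enforced as written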
@@ -567,24 +565,19 @@ const serverWorkerResetThreshold = 1 << 16
 // [1] https://github.com/golang/go/issues/18138
 func (s *Server) serverWorker() {
 	for completed := 0; completed < serverWorkerResetThreshold; completed++ {
-		data, ok := <-s.serverWorkerChannel
+		f, ok := <-s.serverWorkerChannel
 		if !ok {
 			return
 		}
-		s.handleSingleStream(data)
+		f()
 	}
 	go s.serverWorker()
 }
 
-func (s *Server) handleSingleStream(data *serverWorkerData) {
-	defer data.wg.Done()
-	s.handleStream(data.st, data.stream, s.traceInfo(data.st, data.stream))
-}
-
 // initServerWorkers creates worker goroutines and a channel to process incoming
 // connections to reduce the time spent overall on runtime.morestack.
 func (s *Server) initServerWorkers() {
-	s.serverWorkerChannel = make(chan *serverWorkerData)
+	s.serverWorkerChannel = make(chan func())
 	for i := uint32(0); i < s.opts.numServerWorkers; i++ {
 		go s.serverWorker()
 	}
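As a standalone sketch (illustrative names, not the library's API), the worker pattern above in isolation: workers consume closures from a shared channel and replace themselves after a fixed number of tasks, so a goroutine stack grown by one large request is eventually returned to the runtime, which is the runtime.morestack concern referenced in [1].

package workerpool

const workerResetThreshold = 1 << 16 // mirrors serverWorkerResetThreshold

// worker drains tasks until the channel is closed. After a fixed number of
// tasks it replaces itself with a fresh goroutine, so a stack grown by one
// large task does not stay pinned for the lifetime of the pool.
func worker(tasks <-chan func()) {
	for completed := 0; completed < workerResetThreshold; completed++ {
		f, ok := <-tasks
		if !ok {
			return
		}
		f()
	}
	go worker(tasks) // fresh goroutine, fresh small stack
}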
@@ -943,21 +936,26 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
 	defer st.Close(errors.New("finished serving streams for the server transport"))
 	var wg sync.WaitGroup
 
+	streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
 	st.HandleStreams(func(stream *transport.Stream) {
 		wg.Add(1)
+
+		streamQuota.acquire()
+		f := func() {
+			defer streamQuota.release()
+			defer wg.Done()
+			s.handleStream(st, stream, s.traceInfo(st, stream))
+		}
+
 		if s.opts.numServerWorkers > 0 {
-			data := &serverWorkerData{st: st, wg: &wg, stream: stream}
 			select {
-			case s.serverWorkerChannel <- data:
+			case s.serverWorkerChannel <- f:
 				return
 			default:
 				// If all stream workers are busy, fallback to the default code path.
 			}
 		}
-		go func() {
-			defer wg.Done()
-			s.handleStream(st, stream, s.traceInfo(st, stream))
-		}()
+		go f()
 	}, func(ctx context.Context, method string) context.Context {
 		if !EnableTracing {
 			return ctx
@@ -2052,3 +2050,32 @@ func validateSendCompressor(name, clientCompressors string) error {
 	}
 	return fmt.Errorf("client does not support compressor %q", name)
 }
+
+// atomicSemaphore implements a blocking, counting semaphore. acquire should be
+// called synchronously; release may be called asynchronously.
+type atomicSemaphore struct {
+	n    int64
+	wait chan struct{}
+}
+
+func (q *atomicSemaphore) acquire() {
+	if atomic.AddInt64(&q.n, -1) < 0 {
+		// We ran out of quota. Block until a release happens.
+		<-q.wait
+	}
+}
+
+func (q *atomicSemaphore) release() {
+	// N.B. the "<= 0" check below should allow for this to work with multiple
+	// concurrent calls to acquire, but also note that with synchronous calls to
+	// acquire, as our system does, n will never be less than -1. There are
+	// fairness issues (queuing) to consider if this was to be generalized.
+	if atomic.AddInt64(&q.n, 1) <= 0 {
+		// An acquire was waiting on us. Unblock it.
+		q.wait <- struct{}{}
+	}
+}
+
+func newHandlerQuota(n uint32) *atomicSemaphore {
+	return &atomicSemaphore{n: int64(n), wait: make(chan struct{}, 1)}
+}
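To make the semaphore's blocking semantics concrete, a small usage sketch under the same assumptions as serveStreams: acquire is called from a single goroutine, release from the handler goroutines. It assumes the atomicSemaphore definitions above are in the same package; the sleep stands in for handling a stream.

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	quota := newHandlerQuota(2) // at most two handlers in flight
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		quota.acquire() // blocks once two handlers are running
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			defer quota.release() // wakes a blocked acquire, if any
			fmt.Println("handling stream", id)
			time.Sleep(10 * time.Millisecond)
		}(i)
	}
	wg.Wait()
}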