diff --git a/internal/transport/transport.go b/internal/transport/transport.go
index bfab940bd0de..8dcd9266f9d7 100644
--- a/internal/transport/transport.go
+++ b/internal/transport/transport.go
@@ -285,6 +285,11 @@ type Stream struct {
 	contentSubtype string
 }
 
+// ID returns the stream ID.
+func (s *Stream) ID() uint32 {
+	return s.id
+}
+
 // isHeaderSent is only valid on the server-side.
 func (s *Stream) isHeaderSent() bool {
 	return atomic.LoadUint32(&s.headerSent) == 1
diff --git a/server.go b/server.go
index ae369b71ab97..60080906038e 100644
--- a/server.go
+++ b/server.go
@@ -712,15 +712,67 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr
 	return st
 }
 
+func floorCPUCount() uint32 {
+	n := uint32(runtime.NumCPU())
+	for i := uint32(1 << 31); i >= 2; i >>= 1 {
+		if n&i > 0 {
+			return i
+		}
+	}
+
+	return 1
+}
+
+// numWorkers defines the number of stream-handling workers. In experiments
+// on machines with different CPU counts, rounding the number of available
+// CPUs down to the nearest power of two was found to be optimal for
+// performance across the board (QPS and latency).
+var numWorkers = floorCPUCount()
+
+// workerMask is used so that a cheap bitwise AND can replace an expensive
+// integer modulo operation when picking a worker.
+var workerMask = numWorkers - 1
+
+// workerStackReset defines how often a worker's stack is reset. Every N
+// requests, a worker spawns a new goroutine in its place and exits, so that
+// a stack grown by a large request does not live in memory forever. 2^16
+// should let each goroutine stack live for at least a few seconds in a
+// typical workload (assuming a few thousand requests per second).
+const workerStackReset = 1 << 16
+
+func (s *Server) streamWorker(st transport.ServerTransport, wg *sync.WaitGroup, ch chan *transport.Stream) {
+	completed := 0
+	for stream := range ch {
+		s.handleStream(st, stream, s.traceInfo(st, stream))
+		wg.Done()
+		completed++
+		if completed == workerStackReset {
+			go s.streamWorker(st, wg, ch)
+			return
+		}
+	}
+}
+
 func (s *Server) serveStreams(st transport.ServerTransport) {
 	defer st.Close()
 	var wg sync.WaitGroup
+
+	streamChannels := make([]chan *transport.Stream, numWorkers)
+	for i := range streamChannels {
+		streamChannels[i] = make(chan *transport.Stream)
+		go s.streamWorker(st, &wg, streamChannels[i])
+	}
+
 	st.HandleStreams(func(stream *transport.Stream) {
 		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			s.handleStream(st, stream, s.traceInfo(st, stream))
-		}()
+		select {
+		case streamChannels[stream.ID()&workerMask] <- stream:
+		default:
+			go func() {
+				s.handleStream(st, stream, s.traceInfo(st, stream))
+				wg.Done()
+			}()
+		}
 	}, func(ctx context.Context, method string) context.Context {
 		if !EnableTracing {
 			return ctx
@@ -729,6 +781,10 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
 		return trace.NewContext(ctx, tr)
 	})
 	wg.Wait()
+
+	for _, ch := range streamChannels {
+		close(ch)
+	}
 }
 
 var _ http.Handler = (*Server)(nil)
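
The dispatch in serveStreams relies on numWorkers being a power of two: only then is stream.ID()&workerMask equivalent to stream.ID()%numWorkers, so the masked index always lands on a valid worker channel. Below is a minimal standalone sketch of that equivalence, assuming the same floor-to-power-of-two operation as floorCPUCount; floorPowerOfTwo, the sample IDs, and main are illustrative and not part of the patch.

package main

import (
	"fmt"
	"runtime"
)

// floorPowerOfTwo rounds n down to the nearest power of two (minimum 1),
// the same operation floorCPUCount applies to runtime.NumCPU().
func floorPowerOfTwo(n uint32) uint32 {
	for i := uint32(1 << 31); i >= 2; i >>= 1 {
		if n&i > 0 {
			return i
		}
	}
	return 1
}

func main() {
	workers := floorPowerOfTwo(uint32(runtime.NumCPU()))
	mask := workers - 1

	// Because workers is a power of two, id&mask picks the same worker as
	// id%workers, without paying for an integer division on every stream.
	for _, id := range []uint32{1, 3, 64, 1027} {
		fmt.Printf("stream %4d -> worker %d (modulo gives %d)\n", id, id&mask, id%workers)
	}
}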
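The serveStreams change combines three ideas: one unbuffered channel per worker keyed by stream ID, a non-blocking send that falls back to a one-off goroutine when the target worker is busy, and workers that replace themselves after workerStackReset requests so a grown stack can be reclaimed. The sketch below shows the same pattern outside gRPC, with a toy int job in place of *transport.Stream and a deliberately tiny reset threshold so the handoff is observable; every name in it is hypothetical.

package main

import (
	"fmt"
	"sync"
)

// stackReset plays the role of workerStackReset: after this many jobs a
// worker hands its channel to a fresh goroutine and exits, discarding its
// (possibly grown) stack. Kept tiny here so the handoff actually happens.
const stackReset = 4

func worker(id int, wg *sync.WaitGroup, ch chan int) {
	completed := 0
	for job := range ch {
		fmt.Printf("worker %d handled job %d\n", id, job)
		wg.Done()
		completed++
		if completed == stackReset {
			go worker(id, wg, ch) // replacement goroutine takes over the channel
			return
		}
	}
}

func main() {
	const numWorkers = 2 // must be a power of two for the mask below
	var wg sync.WaitGroup

	channels := make([]chan int, numWorkers)
	for i := range channels {
		channels[i] = make(chan int)
		go worker(i, &wg, channels[i])
	}

	for job := 0; job < 16; job++ {
		wg.Add(1)
		select {
		case channels[job&(numWorkers-1)] <- job:
		default:
			// Target worker is busy: handle the job in a one-off goroutine
			// instead of blocking the dispatcher.
			go func(j int) {
				fmt.Printf("overflow goroutine handled job %d\n", j)
				wg.Done()
			}(job)
		}
	}

	wg.Wait()
	for _, ch := range channels {
		close(ch)
	}
}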