diff --git a/internal/transport/transport.go b/internal/transport/transport.go index bfab940bd0de..6aa2d7e671dd 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -285,6 +285,10 @@ type Stream struct { contentSubtype string } +func (s *Stream) ID() uint32 { + return s.id +} + // isHeaderSent is only valid on the server-side. func (s *Stream) isHeaderSent() bool { return atomic.LoadUint32(&s.headerSent) == 1 diff --git a/server.go b/server.go index ae369b71ab97..5b26d18a638b 100644 --- a/server.go +++ b/server.go @@ -712,15 +712,53 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr return st } +func floorCPUCount() uint32 { + n := uint32(runtime.NumCPU()) + for i := uint32(1 << 31); i >= 2; i >>= 1 { + if n&i > 0 { + return i + } + } + + return 1 +} + +// numWorkers defines the number of stream handling workers. After experiments +// with different CPU counts, using the floor of the number of CPUs available +// was found to be the number optimal for performance across the board (QPS, +// latency). +var numWorkers = floorCPUCount() + +// workerMask is used to perform bitwise AND operations instead of expensive +// module operations on integers. +var workerMask = numWorkers - 1 + func (s *Server) serveStreams(st transport.ServerTransport) { defer st.Close() var wg sync.WaitGroup - st.HandleStreams(func(stream *transport.Stream) { - wg.Add(1) + + streamChannels := make([]chan *transport.Stream, numWorkers) + for i := range streamChannels { + ch := make(chan *transport.Stream) go func() { - defer wg.Done() - s.handleStream(st, stream, s.traceInfo(st, stream)) + for stream := range ch { + s.handleStream(st, stream, s.traceInfo(st, stream)) + wg.Done() + } }() + streamChannels[i] = ch + } + + st.HandleStreams(func(stream *transport.Stream) { + wg.Add(1) + select { + case streamChannels[stream.ID() & workerMask] <- stream: + default: + go func() { + s.handleStream(st, stream, s.traceInfo(st, stream)) + wg.Done() + }() + } }, func(ctx context.Context, method string) context.Context { if !EnableTracing { return ctx @@ -729,6 +767,10 @@ func (s *Server) serveStreams(st transport.ServerTransport) { return trace.NewContext(ctx, tr) }) wg.Wait() + + for _, ch := range streamChannels { + close(ch) + } } var _ http.Handler = (*Server)(nil)