diff --git a/http2/server.go b/http2/server.go index 0e670dedd..54655a509 100644 --- a/http2/server.go +++ b/http2/server.go @@ -2,17 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TODO: replace all <-sc.doneServing with reads from the stream's cw -// instead, and make sure that on close we close all open -// streams. then remove doneServing? - -// TODO: re-audit GOAWAY support. Consider each incoming frame type and -// whether it should be ignored during graceful shutdown. - -// TODO: disconnect idle clients. GFE seems to do 4 minutes. make -// configurable? or maximum number of idle clients and remove the -// oldest? - // TODO: turn off the serve goroutine when idle, so // an idle conn only has the readFrames goroutine active. (which could // also be optimized probably to pin less memory in crypto/tls). This @@ -114,6 +103,11 @@ type Server struct { // PermitProhibitedCipherSuites, if true, permits the use of // cipher suites prohibited by the HTTP/2 spec. PermitProhibitedCipherSuites bool + + // IdleTimeout specifies how long until idle clients should be + // closed with a GOAWAY frame. PING frames are not considered + // activity for the purposes of IdleTimeout. + IdleTimeout time.Duration } func (s *Server) maxReadFrameSize() uint32 { @@ -390,6 +384,8 @@ type serverConn struct { goAwayCode ErrCode shutdownTimerCh <-chan time.Time // nil until used shutdownTimer *time.Timer // nil until used + idleTimer *time.Timer // nil if unused + idleTimerCh <-chan time.Time // nil if unused // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer @@ -681,6 +677,12 @@ func (sc *serverConn) serve() { sc.setConnState(http.StateActive) sc.setConnState(http.StateIdle) + if sc.srv.IdleTimeout != 0 { + sc.idleTimer = time.NewTimer(sc.srv.IdleTimeout) + defer sc.idleTimer.Stop() + sc.idleTimerCh = sc.idleTimer.C + } + go sc.readFrames() // closed by defer sc.conn.Close above settingsTimer := time.NewTimer(firstSettingsTimeout) @@ -709,6 +711,9 @@ func (sc *serverConn) serve() { case <-sc.shutdownTimerCh: sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) return + case <-sc.idleTimerCh: + sc.vlogf("connection is idle") + sc.goAway(ErrCodeNo) case fn := <-sc.testHookCh: fn(loopNum) } @@ -1114,12 +1119,18 @@ func (sc *serverConn) processPing(f *PingFrame) error { // PROTOCOL_ERROR." return ConnectionError(ErrCodeProtocol) } + if sc.inGoAway { + return nil + } sc.writeFrame(frameWriteMsg{write: writePingAck{f}}) return nil } func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error { sc.serveG.check() + if sc.inGoAway { + return nil + } switch { case f.StreamID != 0: // stream-level flow control state, st := sc.state(f.StreamID) @@ -1152,6 +1163,9 @@ func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error { func (sc *serverConn) processResetStream(f *RSTStreamFrame) error { sc.serveG.check() + if sc.inGoAway { + return nil + } state, st := sc.state(f.StreamID) if state == stateIdle { @@ -1181,6 +1195,9 @@ func (sc *serverConn) closeStream(st *stream, err error) { sc.setConnState(http.StateIdle) } delete(sc.streams, st.id) + if len(sc.streams) == 0 && sc.srv.IdleTimeout != 0 { + sc.idleTimer.Reset(sc.srv.IdleTimeout) + } if p := st.body; p != nil { // Return any buffered unread bytes worth of conn-level flow control. // See golang.org/issue/16481 @@ -1204,6 +1221,9 @@ func (sc *serverConn) processSettings(f *SettingsFrame) error { } return nil } + if sc.inGoAway { + return nil + } if err := f.ForeachSetting(sc.processSetting); err != nil { return err } @@ -1275,6 +1295,9 @@ func (sc *serverConn) processSettingInitialWindowSize(val uint32) error { func (sc *serverConn) processData(f *DataFrame) error { sc.serveG.check() + if sc.inGoAway { + return nil + } data := f.Data() // "If a DATA frame is received whose stream is not in "open" @@ -1412,6 +1435,10 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { } sc.maxStreamID = id + if sc.idleTimer != nil { + sc.idleTimer.Stop() + } + ctx, cancelCtx := contextWithCancel(sc.baseCtx) st = &stream{ sc: sc, @@ -1524,6 +1551,9 @@ func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error { } func (sc *serverConn) processPriority(f *PriorityFrame) error { + if sc.inGoAway { + return nil + } adjustStreamPriority(sc.streams, f.StreamID, f.PriorityParam) return nil } diff --git a/http2/server_test.go b/http2/server_test.go index 31d12f2ea..e7309b94c 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -86,12 +86,15 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} } var onlyServer, quiet bool + h2server := new(Server) for _, opt := range opts { switch v := opt.(type) { case func(*tls.Config): v(tlsConfig) case func(*httptest.Server): v(ts) + case func(*Server): + v(h2server) case serverTesterOpt: switch v { case optOnlyServer: @@ -106,7 +109,7 @@ func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{} } } - ConfigureServer(ts.Config, &Server{}) + ConfigureServer(ts.Config, h2server) st := &serverTester{ t: t, @@ -3406,6 +3409,52 @@ func TestUnreadFlowControlReturned_Server(t *testing.T) { } +func TestServerIdleTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + }, func(h2s *Server) { + h2s.IdleTimeout = 500 * time.Millisecond + }) + defer st.Close() + + st.greet() + ga := st.wantGoAway() + if ga.ErrCode != ErrCodeNo { + t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) + } +} + +func TestServerIdleTimeout_AfterRequest(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + const timeout = 250 * time.Millisecond + + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + time.Sleep(timeout * 2) + }, func(h2s *Server) { + h2s.IdleTimeout = timeout + }) + defer st.Close() + + st.greet() + + // Send a request which takes twice the timeout. Verifies the + // idle timeout doesn't fire while we're in a request: + st.bodylessReq1() + st.wantHeaders() + + // But the idle timeout should be rearmed after the request + // is done: + ga := st.wantGoAway() + if ga.ErrCode != ErrCodeNo { + t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode) + } +} + // grpc-go closes the Request.Body currently with a Read. // Verify that it doesn't race. // See https://github.com/grpc/grpc-go/pull/938 diff --git a/http2/transport_test.go b/http2/transport_test.go index f0af30ac4..2006a3d15 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -1947,8 +1947,17 @@ func TestTransportNewTLSConfig(t *testing.T) { }, } for i, tt := range tests { + // Ignore the session ticket keys part, which ends up populating + // unexported fields in the Config: + if tt.conf != nil { + tt.conf.SessionTicketsDisabled = true + } + tr := &Transport{TLSClientConfig: tt.conf} got := tr.newTLSConfig(tt.host) + + got.SessionTicketsDisabled = false + if !reflect.DeepEqual(got, tt.want) { t.Errorf("%d. got %#v; want %#v", i, got, tt.want) }