diff --git a/http2/frame.go b/http2/frame.go index 184ac45fe..c1f6b90dc 100644 --- a/http2/frame.go +++ b/http2/frame.go @@ -662,6 +662,15 @@ func (f *Framer) WriteData(streamID uint32, endStream bool, data []byte) error { // It is the caller's responsibility not to violate the maximum frame size // and to not call other Write methods concurrently. func (f *Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { + if err := f.startWriteDataPadded(streamID, endStream, data, pad); err != nil { + return err + } + return f.endWrite() +} + +// startWriteDataPadded is WriteDataPadded, but only writes the frame to the Framer's internal buffer. +// The caller should call endWrite to flush the frame to the underlying writer. +func (f *Framer) startWriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { if !validStreamID(streamID) && !f.AllowIllegalWrites { return errStreamID } @@ -691,7 +700,7 @@ func (f *Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []by } f.wbuf = append(f.wbuf, data...) f.wbuf = append(f.wbuf, pad...) - return f.endWrite() + return nil } // A SettingsFrame conveys configuration parameters that affect how diff --git a/http2/server.go b/http2/server.go index 9bd7035bf..8cb14f3c9 100644 --- a/http2/server.go +++ b/http2/server.go @@ -843,8 +843,13 @@ type frameWriteResult struct { // and then reports when it's done. // At most one goroutine can be running writeFrameAsync at a time per // serverConn. -func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest) { - err := wr.write.writeFrame(sc) +func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) { + var err error + if wd == nil { + err = wr.write.writeFrame(sc) + } else { + err = sc.framer.endWrite() + } sc.wroteFrameCh <- frameWriteResult{wr: wr, err: err} } @@ -1251,9 +1256,16 @@ func (sc *serverConn) startFrameWrite(wr FrameWriteRequest) { sc.writingFrameAsync = false err := wr.write.writeFrame(sc) sc.wroteFrame(frameWriteResult{wr: wr, err: err}) + } else if wd, ok := wr.write.(*writeData); ok { + // Encode the frame in the serve goroutine, to ensure we don't have + // any lingering asynchronous references to data passed to Write. + // See https://go.dev/issue/58446. + sc.framer.startWriteDataPadded(wd.streamID, wd.endStream, wd.p, nil) + sc.writingFrameAsync = true + go sc.writeFrameAsync(wr, wd) } else { sc.writingFrameAsync = true - go sc.writeFrameAsync(wr) + go sc.writeFrameAsync(wr, nil) } } diff --git a/http2/server_test.go b/http2/server_test.go index 978cc37b4..d32b2d85b 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -4631,3 +4631,78 @@ func TestCanonicalHeaderCacheGrowth(t *testing.T) { } } } + +// TestServerWriteDoesNotRetainBufferAfterStreamClose checks for access to +// the slice passed to ResponseWriter.Write after Write returns. +// +// Terminating the request stream on the client causes Write to return. +// We should not access the slice after this point. +func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) { + donec := make(chan struct{}) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + defer close(donec) + buf := make([]byte, 1<<20) + var i byte + for { + i++ + _, err := w.Write(buf) + for j := range buf { + buf[j] = byte(i) // trigger race detector + } + if err != nil { + return + } + } + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + req, _ := http.NewRequest("GET", st.ts.URL, nil) + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + <-donec +} + +// TestServerWriteDoesNotRetainBufferAfterServerClose checks for access to +// the slice passed to ResponseWriter.Write after Write returns. +// +// Shutting down the Server causes Write to return. +// We should not access the slice after this point. +func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) { + donec := make(chan struct{}, 1) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + donec <- struct{}{} + defer close(donec) + buf := make([]byte, 1<<20) + var i byte + for { + i++ + _, err := w.Write(buf) + for j := range buf { + buf[j] = byte(i) + } + if err != nil { + return + } + } + }, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + + req, _ := http.NewRequest("GET", st.ts.URL, nil) + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + <-donec + st.ts.Config.Close() + <-donec +}