diff --git a/http2/transport.go b/http2/transport.go index 54acc1e360..7a07ad5913 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -668,6 +668,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro cc.inflow.add(transportDefaultConnFlow + initialWindowSize) cc.bw.Flush() if cc.werr != nil { + cc.Close() return nil, cc.werr } @@ -1033,6 +1034,15 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf bodyWriter := cc.t.getBodyWriterState(cs, body) cs.on100 = bodyWriter.on100 + defer func() { + cc.wmu.Lock() + werr := cc.werr + cc.wmu.Unlock() + if werr != nil { + cc.Close() + } + }() + cc.wmu.Lock() endStream := !hasBody && !hasTrailers werr := cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) diff --git a/http2/transport_test.go b/http2/transport_test.go index 1424f818b2..5995254183 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -4568,3 +4568,65 @@ func TestClientConnTooIdle(t *testing.T) { } } } + +type fakeConnErr struct { + net.Conn + writeErr error + closed bool +} + +func (fce *fakeConnErr) Write(b []byte) (n int, err error) { + return 0, fce.writeErr +} + +func (fce *fakeConnErr) Close() error { + fce.closed = true + return nil +} + +// issue 39337: close the connection on a failed write +func TestTransportNewClientConnCloseOnWriteError(t *testing.T) { + tr := &Transport{} + writeErr := errors.New("write error") + fakeConn := &fakeConnErr{writeErr: writeErr} + _, err := tr.NewClientConn(fakeConn) + if err != writeErr { + t.Fatalf("expected %v, got %v", writeErr, err) + } + if !fakeConn.closed { + t.Error("expected closed conn") + } +} + +func TestTransportRoundtripCloseOnWriteError(t *testing.T) { + req, err := http.NewRequest("GET", "https://dummy.tld/", nil) + if err != nil { + t.Fatal(err) + } + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optOnlyServer) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure} + defer tr.CloseIdleConnections() + cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false) + if err != nil { + t.Fatal(err) + } + + writeErr := errors.New("write error") + cc.wmu.Lock() + cc.werr = writeErr + cc.wmu.Unlock() + + _, err = cc.RoundTrip(req) + if err != writeErr { + t.Fatalf("expected %v, got %v", writeErr, err) + } + + cc.mu.Lock() + closed := cc.closed + cc.mu.Unlock() + if !closed { + t.Fatal("expected closed") + } +}