diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 2a549a9576b9cf..411f6b29120100 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -1478,6 +1478,7 @@ func (t *Transport) dialConnFor(w *wantConn) { defer w.afterDial() ctx := w.getCtxForDial() if ctx == nil { + t.decConnsPerHost(w.key) return } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 698a43530ad6b4..55222a67635114 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -730,6 +730,56 @@ func testTransportMaxConnsPerHost(t *testing.T, mode testMode) { } } +func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) { + run(t, testTransportMaxConnsPerHostDialCancellation, + testNotParallel, // because test uses SetPendingDialHooks + []testMode{http1Mode, https1Mode, http2Mode}, + ) +} + +func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) { + CondSkipHTTP2(t) + + h := HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := w.Write([]byte("foo")) + if err != nil { + t.Fatalf("Write: %v", err) + } + }) + + cst := newClientServerTest(t, mode, h) + defer cst.close() + ts := cst.ts + c := ts.Client() + tr := c.Transport.(*Transport) + tr.MaxConnsPerHost = 1 + + // This request is cancelled when dial is queued, which preempts dialing. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + SetPendingDialHooks(cancel, nil) + defer SetPendingDialHooks(nil, nil) + + req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil) + _, err := c.Do(req) + if !errors.Is(err, context.Canceled) { + t.Errorf("expected error %v, got %v", context.Canceled, err) + } + + // This request should succeed. + SetPendingDialHooks(nil, nil) + req, _ = NewRequest("GET", ts.URL, nil) + resp, err := c.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body failed: %v", err) + } +} + func TestTransportRemovesDeadIdleConnections(t *testing.T) { run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode}) }