From 098a87fb1930b9ef99d394fe1bca75f1bd74ce8d Mon Sep 17 00:00:00 2001
From: Julian Tibble <tibbes@github.com>
Date: Wed, 14 Feb 2024 13:24:52 +0000
Subject: [PATCH] net/http: add missing call to decConnsPerHost

A recent change to Transport.dialConnFor introduced an early return that
skipped dialing. This path did not call decConnsPerHost, which can cause
subsequent HTTP calls to hang if Transport.MaxConnsPerHost is set.

Fixes: #65705

Change-Id: I157591114b02a3a66488d3ead7f1e6dbd374a41c
Reviewed-on: https://go-review.googlesource.com/c/go/+/564036
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Damien Neil <dneil@google.com>
Reviewed-by: Than McIntosh <thanm@google.com>
---
 src/net/http/transport.go      |  1 +
 src/net/http/transport_test.go | 50 ++++++++++++++++++++++++++++++++++
 2 files changed, 51 insertions(+)

diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index 2a549a9576b9cf..411f6b29120100 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -1478,6 +1478,7 @@ func (t *Transport) dialConnFor(w *wantConn) {
 	defer w.afterDial()
 	ctx := w.getCtxForDial()
 	if ctx == nil {
+		t.decConnsPerHost(w.key)
 		return
 	}
 
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index 698a43530ad6b4..55222a67635114 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -730,6 +730,56 @@ func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
 	}
 }
 
+func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
+	run(t, testTransportMaxConnsPerHostDialCancellation,
+		testNotParallel, // because test uses SetPendingDialHooks
+		[]testMode{http1Mode, https1Mode, http2Mode},
+	)
+}
+
+func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
+	CondSkipHTTP2(t)
+
+	h := HandlerFunc(func(w ResponseWriter, r *Request) {
+		_, err := w.Write([]byte("foo"))
+		if err != nil {
+			t.Fatalf("Write: %v", err)
+		}
+	})
+
+	cst := newClientServerTest(t, mode, h)
+	defer cst.close()
+	ts := cst.ts
+	c := ts.Client()
+	tr := c.Transport.(*Transport)
+	tr.MaxConnsPerHost = 1
+
+	// This request is cancelled when dial is queued, which preempts dialing.
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	SetPendingDialHooks(cancel, nil)
+	defer SetPendingDialHooks(nil, nil)
+
+	req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
+	_, err := c.Do(req)
+	if !errors.Is(err, context.Canceled) {
+		t.Errorf("expected error %v, got %v", context.Canceled, err)
+	}
+
+	// This request should succeed.
+	SetPendingDialHooks(nil, nil)
+	req, _ = NewRequest("GET", ts.URL, nil)
+	resp, err := c.Do(req)
+	if err != nil {
+		t.Fatalf("request failed: %v", err)
+	}
+	defer resp.Body.Close()
+	_, err = io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatalf("read body failed: %v", err)
+	}
+}
+
 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
 	run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
 }