diff --git a/http2/client_conn_pool.go b/http2/client_conn_pool.go index 780968d6c..8f51084fc 100644 --- a/http2/client_conn_pool.go +++ b/http2/client_conn_pool.go @@ -94,9 +94,13 @@ func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMis traceGetConn(req, addr) call := p.getStartDialLocked(req.Context(), addr) p.mu.Unlock() - <-call.done - if shouldRetryDial(call, req) { - continue + select { + case <-call.done: + if shouldRetryDial(call, req) { + continue + } + case <-req.Context().Done(): + return nil, req.Context().Err() } cc, err := call.res, call.err if err != nil { diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 553379099..7b6fc56c6 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -10,10 +10,13 @@ package http2 import ( "bytes" "context" + "crypto/tls" "fmt" "io" + "net" "net/http" "reflect" + "sync" "sync/atomic" "testing" "time" @@ -79,6 +82,56 @@ func TestTestClientConn(t *testing.T) { rt.wantBody(nil) } +// TestConnectTimeout tests that a request does not exceed request timeout + dial timeout +func TestConnectTimeout(t *testing.T) { + tr := &Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + // mock a net dialler with 1s timeout, encountering network issue + // keeping dialing until timeout + var dialer = net.Dialer{Timeout: time.Duration(-1)} + select { + case <-time.After(time.Second): + case <-ctx.Done(): + } + return dialer.DialContext(ctx, network, addr) + }, + AllowHTTP: true, + } + + var sg sync.WaitGroup + parentCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for j := 0; j < 2; j++ { + sg.Add(1) + go func() { + for i := 0; i < 10000; i++ { + sg.Add(1) + go func() { + ctx, _ := context.WithTimeout(parentCtx, time.Second) + req, err := http.NewRequestWithContext(ctx, "GET", "http://127.0.0.1:80", nil) + if err != nil { + t.Errorf("NewRequest: %v", err) + } + + start := time.Now() + tr.RoundTrip(req) + duration := time.Since(start) + // duration should not exceed request timeout + dial timeout + if duration > 2*time.Second { + t.Errorf("RoundTrip took %s; want <2s", duration.String()) + } + sg.Done() + }() + time.Sleep(1 * time.Millisecond) + } + sg.Done() + }() + } + + sg.Wait() +} + // A testClientConn allows testing ClientConn.RoundTrip against a fake server. // // A test using testClientConn consists of: