Skip to content

Commit a73d8f5

Browse files
committed
net/http: make Transport send WebSocket upgrade requests over HTTP/1
WebSockets requires HTTP/1 in practice (no spec or implementations work over HTTP/2), so if we get an HTTP request that looks like it's trying to initiate WebSockets, use HTTP/1, like browsers do. This is part of a series of commits to make WebSockets work over httputil.ReverseProxy. See #26937. Updates #26937 Change-Id: I6ad3df9b0a21fddf62fa7d9cacef48e7d5d9585b Reviewed-on: https://go-review.googlesource.com/c/137437 Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
1 parent 3aa3c05 commit a73d8f5

File tree

5 files changed

+68
-12
lines changed

5 files changed

+68
-12
lines changed

src/net/http/clientserver_test.go

+23-1
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ type slurpResult struct {
252252
func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
253253

254254
func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
255-
if res.Proto == wantProto {
255+
if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
256256
res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
257257
} else {
258258
t.Errorf("got %q response; want %q", res.Proto, wantProto)
@@ -1546,3 +1546,25 @@ func TestBidiStreamReverseProxy(t *testing.T) {
15461546
}
15471547

15481548
}
1549+
1550+
// Always use HTTP/1.1 for WebSocket upgrades.
1551+
func TestH12_WebSocketUpgrade(t *testing.T) {
1552+
h12Compare{
1553+
Handler: func(w ResponseWriter, r *Request) {
1554+
h := w.Header()
1555+
h.Set("Foo", "bar")
1556+
},
1557+
ReqFunc: func(c *Client, url string) (*Response, error) {
1558+
req, _ := NewRequest("GET", url, nil)
1559+
req.Header.Set("Connection", "Upgrade")
1560+
req.Header.Set("Upgrade", "WebSocket")
1561+
return c.Do(req)
1562+
},
1563+
EarlyCheckResponse: func(proto string, res *Response) {
1564+
if res.Proto != "HTTP/1.1" {
1565+
t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
1566+
}
1567+
res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0
1568+
},
1569+
}.run(t)
1570+
}

src/net/http/export_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ func (t *Transport) IdleConnStrsForTesting_h2() []string {
155155
func (t *Transport) IdleConnCountForTesting(scheme, addr string) int {
156156
t.idleMu.Lock()
157157
defer t.idleMu.Unlock()
158-
key := connectMethodKey{"", scheme, addr}
158+
key := connectMethodKey{"", scheme, addr, false}
159159
cacheKey := key.String()
160160
for k, conns := range t.idleConn {
161161
if k.String() == cacheKey {
@@ -178,12 +178,12 @@ func (t *Transport) IsIdleForTesting() bool {
178178
}
179179

180180
func (t *Transport) RequestIdleConnChForTesting() {
181-
t.getIdleConnCh(connectMethod{nil, "http", "example.com"})
181+
t.getIdleConnCh(connectMethod{nil, "http", "example.com", false})
182182
}
183183

184184
func (t *Transport) PutIdleTestConn(scheme, addr string) bool {
185185
c, _ := net.Pipe()
186-
key := connectMethodKey{"", scheme, addr}
186+
key := connectMethodKey{"", scheme, addr, false}
187187
select {
188188
case <-t.incHostConnCount(key):
189189
default:

src/net/http/proxy_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func TestCacheKeys(t *testing.T) {
3535
}
3636
proxy = u
3737
}
38-
cm := connectMethod{proxy, tt.scheme, tt.addr}
38+
cm := connectMethod{proxy, tt.scheme, tt.addr, false}
3939
if got := cm.key().String(); got != tt.key {
4040
t.Fatalf("{%q, %q, %q} cache key = %q; want %q", tt.proxy, tt.scheme, tt.addr, got, tt.key)
4141
}

src/net/http/request.go

+7
Original file line numberDiff line numberDiff line change
@@ -1371,3 +1371,10 @@ func requestMethodUsuallyLacksBody(method string) bool {
13711371
}
13721372
return false
13731373
}
1374+
1375+
// requiresHTTP1 reports whether this request requires being sent on
1376+
// an HTTP/1 connection.
1377+
func (r *Request) requiresHTTP1() bool {
1378+
return hasToken(r.Header.Get("Connection"), "upgrade") &&
1379+
strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
1380+
}

src/net/http/transport.go

+34-7
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,19 @@ func (tr *transportRequest) setError(err error) {
382382
tr.mu.Unlock()
383383
}
384384

385+
// useRegisteredProtocol reports whether an alternate protocol (as reqistered
386+
// with Transport.RegisterProtocol) should be respected for this request.
387+
func (t *Transport) useRegisteredProtocol(req *Request) bool {
388+
if req.URL.Scheme == "https" && req.requiresHTTP1() {
389+
// If this request requires HTTP/1, don't use the
390+
// "https" alternate protocol, which is used by the
391+
// HTTP/2 code to take over requests if there's an
392+
// existing cached HTTP/2 connection.
393+
return false
394+
}
395+
return true
396+
}
397+
385398
// roundTrip implements a RoundTripper over HTTP.
386399
func (t *Transport) roundTrip(req *Request) (*Response, error) {
387400
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
@@ -411,10 +424,12 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
411424
}
412425
}
413426

414-
altProto, _ := t.altProto.Load().(map[string]RoundTripper)
415-
if altRT := altProto[scheme]; altRT != nil {
416-
if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
417-
return resp, err
427+
if t.useRegisteredProtocol(req) {
428+
altProto, _ := t.altProto.Load().(map[string]RoundTripper)
429+
if altRT := altProto[scheme]; altRT != nil {
430+
if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
431+
return resp, err
432+
}
418433
}
419434
}
420435
if !isHTTP {
@@ -653,6 +668,7 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM
653668
}
654669
}
655670
}
671+
cm.onlyH1 = treq.requiresHTTP1()
656672
return cm, err
657673
}
658674

@@ -1155,6 +1171,9 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro
11551171
if cfg.ServerName == "" {
11561172
cfg.ServerName = name
11571173
}
1174+
if pconn.cacheKey.onlyH1 {
1175+
cfg.NextProtos = nil
1176+
}
11581177
plainConn := pconn.conn
11591178
tlsConn := tls.Client(plainConn, cfg)
11601179
errc := make(chan error, 2)
@@ -1361,10 +1380,11 @@ func (w persistConnWriter) Write(p []byte) (n int, err error) {
13611380
//
13621381
// A connect method may be of the following types:
13631382
//
1364-
// Cache key form Description
1365-
// ----------------- -------------------------
1383+
// connectMethod.key().String() Description
1384+
// ------------------------------ -------------------------
13661385
// |http|foo.com http directly to server, no proxy
13671386
// |https|foo.com https directly to server, no proxy
1387+
// |https,h1|foo.com https directly to server w/o HTTP/2, no proxy
13681388
// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com
13691389
// http://proxy.com|http http to proxy, http to anywhere after that
13701390
// socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com
@@ -1379,6 +1399,7 @@ type connectMethod struct {
13791399
// then targetAddr is not included in the connect method key, because the socket can
13801400
// be reused for different targetAddr values.
13811401
targetAddr string
1402+
onlyH1 bool // whether to disable HTTP/2 and force HTTP/1
13821403
}
13831404

13841405
func (cm *connectMethod) key() connectMethodKey {
@@ -1394,6 +1415,7 @@ func (cm *connectMethod) key() connectMethodKey {
13941415
proxy: proxyStr,
13951416
scheme: cm.targetScheme,
13961417
addr: targetAddr,
1418+
onlyH1: cm.onlyH1,
13971419
}
13981420
}
13991421

@@ -1428,11 +1450,16 @@ func (cm *connectMethod) tlsHost() string {
14281450
// a URL.
14291451
type connectMethodKey struct {
14301452
proxy, scheme, addr string
1453+
onlyH1 bool
14311454
}
14321455

14331456
func (k connectMethodKey) String() string {
14341457
// Only used by tests.
1435-
return fmt.Sprintf("%s|%s|%s", k.proxy, k.scheme, k.addr)
1458+
var h1 string
1459+
if k.onlyH1 {
1460+
h1 = ",h1"
1461+
}
1462+
return fmt.Sprintf("%s|%s%s|%s", k.proxy, k.scheme, h1, k.addr)
14361463
}
14371464

14381465
// persistConn wraps a connection, usually a persistent one

0 commit comments

Comments
 (0)