diff --git a/internal/transport/proxy.go b/internal/transport/proxy.go index 24fa1032574c..54b224436544 100644 --- a/internal/transport/proxy.go +++ b/internal/transport/proxy.go @@ -107,8 +107,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri } return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump) } - - return &bufConn{Conn: conn, r: r}, nil + // The buffer could contain extra bytes from the target server, so we can't + // discard it. However, in many cases where the server waits for the client + // to send the first message (e.g. when TLS is being used), the buffer will + // be empty, so we can avoid the overhead of reading through this buffer. + if r.Buffered() != 0 { + return &bufConn{Conn: conn, r: r}, nil + } + return conn, nil } // proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy diff --git a/internal/transport/proxy_test.go b/internal/transport/proxy_test.go index 8abee1e7b383..9fdd662ddd82 100644 --- a/internal/transport/proxy_test.go +++ b/internal/transport/proxy_test.go @@ -23,6 +23,7 @@ package transport import ( "bufio" + "bytes" "context" "encoding/base64" "fmt" @@ -58,7 +59,7 @@ type proxyServer struct { requestCheck func(*http.Request) error } -func (p *proxyServer) run() { +func (p *proxyServer) run(waitForServerHello bool) { in, err := p.lis.Accept() if err != nil { return @@ -83,8 +84,26 @@ func (p *proxyServer) run() { p.t.Errorf("failed to dial to server: %v", err) return } + out.SetDeadline(time.Now().Add(defaultTestTimeout)) resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"} - resp.Write(p.in) + var buf bytes.Buffer + resp.Write(&buf) + if waitForServerHello { + // Batch the first message from the server with the http connect + // response. This is done to test the cases in which the grpc client has + // the response to the connect request and proxied packets from the + // destination server when it reads the transport. + b := make([]byte, 50) + bytesRead, err := out.Read(b) + if err != nil { + p.t.Errorf("Got error while reading server hello: %v", err) + in.Close() + out.Close() + return + } + buf.Write(b[0:bytesRead]) + } + p.in.Write(buf.Bytes()) p.out = out go io.Copy(p.in, p.out) go io.Copy(p.out, p.in) @@ -100,7 +119,13 @@ func (p *proxyServer) stop() { } } -func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) { +type testArgs struct { + proxyURLModify func(*url.URL) *url.URL + proxyReqCheck func(*http.Request) error + serverMessage []byte +} + +func testHTTPConnect(t *testing.T, args testArgs) { plis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to listen: %v", err) @@ -108,9 +133,9 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy p := &proxyServer{ t: t, lis: plis, - requestCheck: proxyReqCheck, + requestCheck: args.proxyReqCheck, } - go p.run() + go p.run(len(args.serverMessage) > 0) defer p.stop() blis, err := net.Listen("tcp", "localhost:0") @@ -128,13 +153,14 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy return } defer in.Close() + in.Write(args.serverMessage) in.Read(recvBuf) done <- nil }() // Overwrite the function in the test and restore them in defer. hpfe := func(req *http.Request) (*url.URL, error) { - return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil + return args.proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil } defer overwrite(hpfe)() @@ -143,34 +169,63 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy defer cancel() c, err := proxyDial(ctx, blis.Addr().String(), "test") if err != nil { - t.Fatalf("http connect Dial failed: %v", err) + t.Fatalf("HTTP connect Dial failed: %v", err) } defer c.Close() + c.SetDeadline(time.Now().Add(defaultTestTimeout)) // Send msg on the connection. c.Write(msg) if err := <-done; err != nil { - t.Fatalf("failed to accept: %v", err) + t.Fatalf("Failed to accept: %v", err) } // Check received msg. if string(recvBuf) != string(msg) { - t.Fatalf("received msg: %v, want %v", recvBuf, msg) + t.Fatalf("Received msg: %v, want %v", recvBuf, msg) + } + + if len(args.serverMessage) > 0 { + gotServerMessage := make([]byte, len(args.serverMessage)) + if _, err := c.Read(gotServerMessage); err != nil { + t.Errorf("Got error while reading message from server: %v", err) + return + } + if string(gotServerMessage) != string(args.serverMessage) { + t.Errorf("Message from server: %v, want %v", gotServerMessage, args.serverMessage) + } } } func (s) TestHTTPConnect(t *testing.T) { - testHTTPConnect(t, - func(in *url.URL) *url.URL { + args := testArgs{ + proxyURLModify: func(in *url.URL) *url.URL { return in }, - func(req *http.Request) error { + proxyReqCheck: func(req *http.Request) error { if req.Method != http.MethodConnect { return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) } return nil }, - ) + } + testHTTPConnect(t, args) +} + +func (s) TestHTTPConnectWithServerHello(t *testing.T) { + args := testArgs{ + proxyURLModify: func(in *url.URL) *url.URL { + return in + }, + proxyReqCheck: func(req *http.Request) error { + if req.Method != http.MethodConnect { + return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) + } + return nil + }, + serverMessage: []byte("server-hello"), + } + testHTTPConnect(t, args) } func (s) TestHTTPConnectBasicAuth(t *testing.T) { @@ -178,12 +233,12 @@ func (s) TestHTTPConnectBasicAuth(t *testing.T) { user = "notAUser" password = "notAPassword" ) - testHTTPConnect(t, - func(in *url.URL) *url.URL { + args := testArgs{ + proxyURLModify: func(in *url.URL) *url.URL { in.User = url.UserPassword(user, password) return in }, - func(req *http.Request) error { + proxyReqCheck: func(req *http.Request) error { if req.Method != http.MethodConnect { return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) } @@ -195,7 +250,8 @@ func (s) TestHTTPConnectBasicAuth(t *testing.T) { } return nil }, - ) + } + testHTTPConnect(t, args) } func (s) TestMapAddressEnv(t *testing.T) {