Skip to content

Commit

Permalink
transport: Discard the buffer when empty after http connect handshake (
Browse files Browse the repository at this point in the history
…grpc#7424)

* Discard the buffer when empty after http connect handshake

* configure the proxy to wait for server hello

* Extract test args to a struct

* Change deadline sets
  • Loading branch information
arjan-bal authored and printchard committed Jul 30, 2024
1 parent 936aceb commit 01090b4
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 19 deletions.
10 changes: 8 additions & 2 deletions internal/transport/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri
}
return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump)
}

return &bufConn{Conn: conn, r: r}, nil
// The buffer could contain extra bytes from the target server, so we can't
// discard it. However, in many cases where the server waits for the client
// to send the first message (e.g. when TLS is being used), the buffer will
// be empty, so we can avoid the overhead of reading through this buffer.
if r.Buffered() != 0 {
return &bufConn{Conn: conn, r: r}, nil
}
return conn, nil
}

// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy
Expand Down
90 changes: 73 additions & 17 deletions internal/transport/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package transport

import (
"bufio"
"bytes"
"context"
"encoding/base64"
"fmt"
Expand Down Expand Up @@ -58,7 +59,7 @@ type proxyServer struct {
requestCheck func(*http.Request) error
}

func (p *proxyServer) run() {
func (p *proxyServer) run(waitForServerHello bool) {
in, err := p.lis.Accept()
if err != nil {
return
Expand All @@ -83,8 +84,26 @@ func (p *proxyServer) run() {
p.t.Errorf("failed to dial to server: %v", err)
return
}
out.SetDeadline(time.Now().Add(defaultTestTimeout))
resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"}
resp.Write(p.in)
var buf bytes.Buffer
resp.Write(&buf)
if waitForServerHello {
// Batch the first message from the server with the http connect
// response. This is done to test the cases in which the grpc client has
// the response to the connect request and proxied packets from the
// destination server when it reads the transport.
b := make([]byte, 50)
bytesRead, err := out.Read(b)
if err != nil {
p.t.Errorf("Got error while reading server hello: %v", err)
in.Close()
out.Close()
return
}
buf.Write(b[0:bytesRead])
}
p.in.Write(buf.Bytes())
p.out = out
go io.Copy(p.in, p.out)
go io.Copy(p.out, p.in)
Expand All @@ -100,17 +119,23 @@ func (p *proxyServer) stop() {
}
}

func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) {
type testArgs struct {
proxyURLModify func(*url.URL) *url.URL
proxyReqCheck func(*http.Request) error
serverMessage []byte
}

func testHTTPConnect(t *testing.T, args testArgs) {
plis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
p := &proxyServer{
t: t,
lis: plis,
requestCheck: proxyReqCheck,
requestCheck: args.proxyReqCheck,
}
go p.run()
go p.run(len(args.serverMessage) > 0)
defer p.stop()

blis, err := net.Listen("tcp", "localhost:0")
Expand All @@ -128,13 +153,14 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy
return
}
defer in.Close()
in.Write(args.serverMessage)
in.Read(recvBuf)
done <- nil
}()

// Overwrite the function in the test and restore them in defer.
hpfe := func(req *http.Request) (*url.URL, error) {
return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
return args.proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
}
defer overwrite(hpfe)()

Expand All @@ -143,47 +169,76 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy
defer cancel()
c, err := proxyDial(ctx, blis.Addr().String(), "test")
if err != nil {
t.Fatalf("http connect Dial failed: %v", err)
t.Fatalf("HTTP connect Dial failed: %v", err)
}
defer c.Close()
c.SetDeadline(time.Now().Add(defaultTestTimeout))

// Send msg on the connection.
c.Write(msg)
if err := <-done; err != nil {
t.Fatalf("failed to accept: %v", err)
t.Fatalf("Failed to accept: %v", err)
}

// Check received msg.
if string(recvBuf) != string(msg) {
t.Fatalf("received msg: %v, want %v", recvBuf, msg)
t.Fatalf("Received msg: %v, want %v", recvBuf, msg)
}

if len(args.serverMessage) > 0 {
gotServerMessage := make([]byte, len(args.serverMessage))
if _, err := c.Read(gotServerMessage); err != nil {
t.Errorf("Got error while reading message from server: %v", err)
return
}
if string(gotServerMessage) != string(args.serverMessage) {
t.Errorf("Message from server: %v, want %v", gotServerMessage, args.serverMessage)
}
}
}

func (s) TestHTTPConnect(t *testing.T) {
testHTTPConnect(t,
func(in *url.URL) *url.URL {
args := testArgs{
proxyURLModify: func(in *url.URL) *url.URL {
return in
},
func(req *http.Request) error {
proxyReqCheck: func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
return nil
},
)
}
testHTTPConnect(t, args)
}

func (s) TestHTTPConnectWithServerHello(t *testing.T) {
args := testArgs{
proxyURLModify: func(in *url.URL) *url.URL {
return in
},
proxyReqCheck: func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
return nil
},
serverMessage: []byte("server-hello"),
}
testHTTPConnect(t, args)
}

func (s) TestHTTPConnectBasicAuth(t *testing.T) {
const (
user = "notAUser"
password = "notAPassword"
)
testHTTPConnect(t,
func(in *url.URL) *url.URL {
args := testArgs{
proxyURLModify: func(in *url.URL) *url.URL {
in.User = url.UserPassword(user, password)
return in
},
func(req *http.Request) error {
proxyReqCheck: func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
Expand All @@ -195,7 +250,8 @@ func (s) TestHTTPConnectBasicAuth(t *testing.T) {
}
return nil
},
)
}
testHTTPConnect(t, args)
}

func (s) TestMapAddressEnv(t *testing.T) {
Expand Down

0 comments on commit 01090b4

Please sign in to comment.