diff --git a/clientconn.go b/clientconn.go index 4f57b55434f9..030a7ef10318 100644 --- a/clientconn.go +++ b/clientconn.go @@ -225,7 +225,12 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) { func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) { // At the end of this method, we kick the channel out of idle, rather than // waiting for the first rpc. - opts = append([]DialOption{withDefaultScheme("passthrough")}, opts...) + // + // WithLocalDNSResolution dial option in `grpc.Dial` ensures that it + // preserves behavior: when default scheme passthrough is used, skip + // hostname resolution, when "dns" is used for resolution, perform + // resolution on the client. + opts = append([]DialOption{withDefaultScheme("passthrough"), WithLocalDNSResolution()}, opts...) cc, err := NewClient(target, opts...) if err != nil { return nil, err diff --git a/dialoptions.go b/dialoptions.go index f3a045296a46..e565365fee19 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -94,6 +94,8 @@ type dialOptions struct { idleTimeout time.Duration defaultScheme string maxCallAttempts int + enableLocalDNSResolution bool // Specifies if target hostnames should be resolved when proxying is enabled. + useProxy bool // Specifies if a server should be connected via proxy. } // DialOption configures how we set up the connection. @@ -377,7 +379,22 @@ func WithInsecure() DialOption { // later release. func WithNoProxy() DialOption { return newFuncDialOption(func(o *dialOptions) { - o.copts.UseProxy = false + o.useProxy = false + }) +} + +// WithLocalDNSResolution forces local DNS name resolution even when a proxy is +// specified in the environment. By default, the server name is provided +// directly to the proxy as part of the CONNECT handshake. This is ignored if +// WithNoProxy is used. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a +// later release. +func WithLocalDNSResolution() DialOption { + return newFuncDialOption(func(o *dialOptions) { + o.enableLocalDNSResolution = true }) } @@ -667,14 +684,15 @@ func defaultDialOptions() dialOptions { copts: transport.ConnectOptions{ ReadBufferSize: defaultReadBufSize, WriteBufferSize: defaultWriteBufSize, - UseProxy: true, UserAgent: grpcUA, BufferPool: mem.DefaultBufferPool(), }, - bs: internalbackoff.DefaultExponential, - idleTimeout: 30 * time.Minute, - defaultScheme: "dns", - maxCallAttempts: defaultMaxCallAttempts, + bs: internalbackoff.DefaultExponential, + idleTimeout: 30 * time.Minute, + defaultScheme: "dns", + maxCallAttempts: defaultMaxCallAttempts, + useProxy: true, + enableLocalDNSResolution: false, } } diff --git a/internal/proxyattributes/proxyattributes.go b/internal/proxyattributes/proxyattributes.go index d25d33efd373..1f61f1a49d7a 100644 --- a/internal/proxyattributes/proxyattributes.go +++ b/internal/proxyattributes/proxyattributes.go @@ -33,7 +33,7 @@ const proxyOptionsKey = keyType("grpc.resolver.delegatingresolver.proxyOptions") // Options holds the proxy connection details needed during the CONNECT // handshake. type Options struct { - User url.Userinfo + User *url.Userinfo ConnectAddr string } @@ -44,7 +44,8 @@ func Set(addr resolver.Address, opts Options) resolver.Address { } // Get returns the Options for the proxy [resolver.Address] and a boolean -// value representing if the attribute is present or not. +// value representing if the attribute is present or not. The returned data +// should not be mutated. func Get(addr resolver.Address) (Options, bool) { if a := addr.Attributes.Value(proxyOptionsKey); a != nil { return a.(Options), true diff --git a/internal/proxyattributes/proxyattributes_test.go b/internal/proxyattributes/proxyattributes_test.go index 225b2919d5d9..2c938c396160 100644 --- a/internal/proxyattributes/proxyattributes_test.go +++ b/internal/proxyattributes/proxyattributes_test.go @@ -42,7 +42,7 @@ func (s) TestGet(t *testing.T) { name string addr resolver.Address wantConnectAddr string - wantUser url.Userinfo + wantUser *url.Userinfo wantAttrPresent bool }{ { @@ -61,10 +61,10 @@ func (s) TestGet(t *testing.T) { addr: resolver.Address{ Addr: "test-address", Attributes: attributes.New(proxyOptionsKey, Options{ - User: *user, + User: user, }), }, - wantUser: *user, + wantUser: user, wantAttrPresent: true, }, { @@ -97,7 +97,7 @@ func (s) TestGet(t *testing.T) { func (s) TestSet(t *testing.T) { addr := resolver.Address{Addr: "test-address"} pOpts := Options{ - User: *url.UserPassword("username", "password"), + User: url.UserPassword("username", "password"), ConnectAddr: "proxy-address", } @@ -108,7 +108,7 @@ func (s) TestSet(t *testing.T) { t.Errorf("Get(%v) = %v, want %v ", populatedAddr, attrPresent, true) } if got, want := gotOption.ConnectAddr, pOpts.ConnectAddr; got != want { - t.Errorf("Unexpected ConnectAddr proxy atrribute = %v, want %v", got, want) + t.Errorf("unexpected ConnectAddr proxy atrribute = %v, want %v", got, want) } if got, want := gotOption.User, pOpts.User; got != want { t.Errorf("unexpected User proxy attribute = %v, want %v", got, want) diff --git a/internal/resolver/delegatingresolver/delegatingresolver.go b/internal/resolver/delegatingresolver/delegatingresolver.go index 6050e3d055bb..a6c647013388 100644 --- a/internal/resolver/delegatingresolver/delegatingresolver.go +++ b/internal/resolver/delegatingresolver/delegatingresolver.go @@ -205,13 +205,9 @@ func (r *delegatingResolver) updateClientConnStateLocked() error { proxyAddr = resolver.Address{Addr: r.proxyURL.Host} } var addresses []resolver.Address - var user url.Userinfo - if r.proxyURL.User != nil { - user = *r.proxyURL.User - } for _, targetAddr := range (*r.targetResolverState).Addresses { addresses = append(addresses, proxyattributes.Set(proxyAddr, proxyattributes.Options{ - User: user, + User: r.proxyURL.User, ConnectAddr: targetAddr.Addr, })) } @@ -229,7 +225,7 @@ func (r *delegatingResolver) updateClientConnStateLocked() error { for _, proxyAddr := range r.proxyAddrs { for _, targetAddr := range endpt.Addresses { addrs = append(addrs, proxyattributes.Set(proxyAddr, proxyattributes.Options{ - User: user, + User: r.proxyURL.User, ConnectAddr: targetAddr.Addr, })) } diff --git a/internal/testutils/proxyserver/proxyserver.go b/internal/testutils/proxyserver/proxyserver.go new file mode 100644 index 000000000000..576ab11ad1d0 --- /dev/null +++ b/internal/testutils/proxyserver/proxyserver.go @@ -0,0 +1,134 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package proxyserver provides an implementation of a proxy server for testing purposes. +// The server supports only a single incoming connection at a time and is not concurrent. +// It handles only HTTP CONNECT requests; other HTTP methods are not supported. +package proxyserver + +import ( + "bufio" + "bytes" + "io" + "net" + "net/http" + "testing" + "time" + + "google.golang.org/grpc/internal/testutils" +) + +// ProxyServer represents a test proxy server. +type ProxyServer struct { + lis net.Listener + in net.Conn // Connection from the client to the proxy. + out net.Conn // Connection from the proxy to the backend. + onRequest func(*http.Request) // Function to check the request sent to proxy. + Addr string // Address of the proxy +} + +const defaultTestTimeout = 10 * time.Second + +// Stop closes the ProxyServer and its connections to client and server. +func (p *ProxyServer) stop() { + p.lis.Close() + if p.in != nil { + p.in.Close() + } + if p.out != nil { + p.out.Close() + } +} + +func (p *ProxyServer) handleRequest(t *testing.T, in net.Conn, waitForServerHello bool) { + req, err := http.ReadRequest(bufio.NewReader(in)) + if err != nil { + t.Errorf("failed to read CONNECT req: %v", err) + return + } + if req.Method != http.MethodConnect { + t.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) + } + p.onRequest(req) + + t.Logf("Dialing to %s", req.URL.Host) + out, err := net.Dial("tcp", req.URL.Host) + if err != nil { + in.Close() + t.Logf("failed to dial to server: %v", err) + return + } + out.SetDeadline(time.Now().Add(defaultTestTimeout)) + resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"} + var buf bytes.Buffer + resp.Write(&buf) + + if waitForServerHello { + // Batch the first message from the server with the http connect + // response. This is done to test the cases in which the grpc client has + // the response to the connect request and proxied packets from the + // destination server when it reads the transport. + b := make([]byte, 50) + bytesRead, err := out.Read(b) + if err != nil { + t.Errorf("Got error while reading server hello: %v", err) + in.Close() + out.Close() + return + } + buf.Write(b[0:bytesRead]) + } + p.in = in + p.in.Write(buf.Bytes()) + p.out = out + + go io.Copy(p.in, p.out) + go io.Copy(p.out, p.in) +} + +// New initializes and starts a proxy server, registers a cleanup to +// stop it, and returns a ProxyServer. +func New(t *testing.T, reqCheck func(*http.Request), waitForServerHello bool) *ProxyServer { + t.Helper() + pLis, err := testutils.LocalTCPListener() + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + + p := &ProxyServer{ + lis: pLis, + onRequest: reqCheck, + Addr: pLis.Addr().String(), + } + + // Start the proxy server. + go func() { + for { + in, err := p.lis.Accept() + if err != nil { + return + } + // p.handleRequest is not invoked in a goroutine because the test + // proxy currently supports handling only one connection at a time. + p.handleRequest(t, in, waitForServerHello) + } + }() + t.Logf("Started proxy at: %q", pLis.Addr().String()) + t.Cleanup(p.stop) + return p +} diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index f323ab7f45a6..513dbb93d550 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -43,6 +43,7 @@ import ( "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcutil" imetadata "google.golang.org/grpc/internal/metadata" + "google.golang.org/grpc/internal/proxyattributes" istatus "google.golang.org/grpc/internal/status" isyscall "google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/internal/transport/networktype" @@ -153,7 +154,7 @@ type http2Client struct { logger *grpclog.PrefixLogger } -func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, useProxy bool, grpcUA string) (net.Conn, error) { +func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr resolver.Address, grpcUA string) (net.Conn, error) { address := addr.Addr networkType, ok := networktype.Get(addr) if fn != nil { @@ -177,8 +178,8 @@ func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error if !ok { networkType, address = parseDialTarget(address) } - if networkType == "tcp" && useProxy { - return proxyDial(ctx, address, grpcUA) + if opts, present := proxyattributes.Get(addr); present { + return proxyDial(ctx, addr, grpcUA, opts) } return internal.NetDialerWithTCPKeepalive().DialContext(ctx, networkType, address) } @@ -217,7 +218,7 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts // address specific arbitrary data to reach custom dialers and credential handshakers. connectCtx = icredentials.NewClientHandshakeInfoContext(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) - conn, err := dial(connectCtx, opts.Dialer, addr, opts.UseProxy, opts.UserAgent) + conn, err := dial(connectCtx, opts.Dialer, addr, opts.UserAgent) if err != nil { if opts.FailOnNonTempDialError { return nil, connectionErrorf(isTemporary(err), err, "transport: error while dialing: %v", err) diff --git a/internal/transport/proxy.go b/internal/transport/proxy.go index 54b224436544..d7738459550b 100644 --- a/internal/transport/proxy.go +++ b/internal/transport/proxy.go @@ -30,34 +30,16 @@ import ( "net/url" "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/proxyattributes" + "google.golang.org/grpc/resolver" ) const proxyAuthHeaderKey = "Proxy-Authorization" -var ( - // The following variable will be overwritten in the tests. - httpProxyFromEnvironment = http.ProxyFromEnvironment -) - -func mapAddress(address string) (*url.URL, error) { - req := &http.Request{ - URL: &url.URL{ - Scheme: "https", - Host: address, - }, - } - url, err := httpProxyFromEnvironment(req) - if err != nil { - return nil, err - } - return url, nil -} - // To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader. -// It's possible that this reader reads more than what's need for the response and stores -// those bytes in the buffer. -// bufConn wraps the original net.Conn and the bufio.Reader to make sure we don't lose the -// bytes in the buffer. +// It's possible that this reader reads more than what's need for the response +// and stores those bytes in the buffer. bufConn wraps the original net.Conn +// and the bufio.Reader to make sure we don't lose the bytes in the buffer. type bufConn struct { net.Conn r io.Reader @@ -72,7 +54,7 @@ func basicAuth(username, password string) string { return base64.StdEncoding.EncodeToString([]byte(auth)) } -func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL, grpcUA string) (_ net.Conn, err error) { +func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, grpcUA string, opts proxyattributes.Options) (_ net.Conn, err error) { defer func() { if err != nil { conn.Close() @@ -81,15 +63,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri req := &http.Request{ Method: http.MethodConnect, - URL: &url.URL{Host: backendAddr}, + URL: &url.URL{Host: opts.ConnectAddr}, Header: map[string][]string{"User-Agent": {grpcUA}}, } - if t := proxyURL.User; t != nil { - u := t.Username() - p, _ := t.Password() + if user := opts.User; user != nil { + u := user.Username() + p, _ := user.Password() req.Header.Add(proxyAuthHeaderKey, "Basic "+basicAuth(u, p)) } - if err := sendHTTPRequest(ctx, req, conn); err != nil { return nil, fmt.Errorf("failed to write the HTTP request: %v", err) } @@ -117,28 +98,13 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri return conn, nil } -// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy -// is necessary, dials, does the HTTP CONNECT handshake, and returns the -// connection. -func proxyDial(ctx context.Context, addr string, grpcUA string) (net.Conn, error) { - newAddr := addr - proxyURL, err := mapAddress(addr) - if err != nil { - return nil, err - } - if proxyURL != nil { - newAddr = proxyURL.Host - } - - conn, err := internal.NetDialerWithTCPKeepalive().DialContext(ctx, "tcp", newAddr) +// proxyDial establishes a TCP connection to the specified address and performs an HTTP CONNECT handshake. +func proxyDial(ctx context.Context, addr resolver.Address, grpcUA string, opts proxyattributes.Options) (net.Conn, error) { + conn, err := internal.NetDialerWithTCPKeepalive().DialContext(ctx, "tcp", addr.Addr) if err != nil { return nil, err } - if proxyURL == nil { - // proxy is disabled if proxyURL is nil. - return conn, err - } - return doHTTPConnectHandshake(ctx, conn, addr, proxyURL, grpcUA) + return doHTTPConnectHandshake(ctx, conn, grpcUA, opts) } func sendHTTPRequest(ctx context.Context, req *http.Request, conn net.Conn) error { diff --git a/internal/transport/proxy_ext_test.go b/internal/transport/proxy_ext_test.go new file mode 100644 index 000000000000..e53b506013bc --- /dev/null +++ b/internal/transport/proxy_ext_test.go @@ -0,0 +1,518 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package transport_test + +import ( + "context" + "encoding/base64" + "fmt" + "net" + "net/http" + "net/netip" + "net/url" + "testing" + "time" + + "golang.org/x/net/http/httpproxy" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/resolver/delegatingresolver" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/testutils/proxyserver" + testgrpc "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" +) + +const defaultTestTimeout = 10 * time.Second + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func startBackendServer(t *testing.T) *stubserver.StubServer { + t.Helper() + backend := &stubserver.StubServer{ + EmptyCallF: func(context.Context, *testgrpc.Empty) (*testgrpc.Empty, error) { return &testgrpc.Empty{}, nil }, + } + if err := backend.StartServer(); err != nil { + t.Fatalf("failed to start backend: %v", err) + } + t.Logf("Started TestService backend at: %q", backend.Address) + t.Cleanup(backend.Stop) + return backend +} + +func isIPAddr(addr string) bool { + _, err := netip.ParseAddr(addr) + return err == nil +} + +// Tests the scenario where grpc.Dial is performed using a proxy with the +// default resolver in the target URI. The test verifies that the connection is +// established to the proxy server, sends the unresolved target URI in the HTTP +// CONNECT request and is successfully connected to the backend server. +func (s) TestGRPCDialWithProxy(t *testing.T) { + backend := startBackendServer(t) + unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address)) + proxyCalled := false + reqCheck := func(req *http.Request) { + proxyCalled = true + host, _, err := net.SplitHostPort(req.URL.Host) + if err != nil { + t.Error(err) + } + if got, want := host, "localhost"; got != want { + t.Errorf(" Unexpected request host: %s , want = %s ", got, want) + } + } + pServer := proxyserver.New(t, reqCheck, false) + // Use "localhost:" to verify the proxy address is handled + // correctly by the delegating resolver and connects to the proxy server + // correctly even when unresolved. + pAddr := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, pServer.Addr)) + + // Overwrite the function in the test and restore them in defer. + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == unresolvedTargetURI { + return &url.URL{ + Scheme: "https", + Host: pAddr, + }, nil + } + t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI) + return nil, nil + } + orighpfe := delegatingresolver.HTTPSProxyFromEnvironment + delegatingresolver.HTTPSProxyFromEnvironment = hpfe + defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.Dial(unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.Dial(%s) failed: %v", unresolvedTargetURI, err) + } + defer conn.Close() + + // Send an empty RPC to the backend through the proxy. + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Fatalf("EmptyCall failed: %v", err) + } + + if !proxyCalled { + t.Fatalf("Proxy not connected") + } +} + +// Tests the scenario where `grpc.Dial` is performed with a proxy and the "dns" +// scheme for the target. The test verifies that the proxy URI is correctly +// resolved and that the target URI resolution on the client preserves the +// original behavior of `grpc.Dial`. It also ensures that a connection is +// established to the proxy server, with the resolved target URI sent in the +// HTTP CONNECT request, successfully connecting to the backend server. +func (s) TestGRPCDialWithDNSAndProxy(t *testing.T) { + backend := startBackendServer(t) + unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address)) + proxyCalled := false + reqCheck := func(req *http.Request) { + proxyCalled = true + + host, _, err := net.SplitHostPort(req.URL.Host) + if err != nil { + t.Error(err) + } + if got, want := isIPAddr(host), true; got != want { + t.Errorf("isIPAddr(%q) = %t, want = %t", host, got, want) + } + } + pServer := proxyserver.New(t, reqCheck, false) + + // Overwrite the function in the test and restore them in defer. + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == unresolvedTargetURI { + return &url.URL{ + Scheme: "https", + Host: pServer.Addr, + }, nil + } + t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI) + return nil, nil + } + orighpfe := delegatingresolver.HTTPSProxyFromEnvironment + delegatingresolver.HTTPSProxyFromEnvironment = hpfe + defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.Dial("dns:///"+unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.Dial(%s) failed: %v", "dns:///"+unresolvedTargetURI, err) + } + defer conn.Close() + + // Send an empty RPC to the backend through the proxy. + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Fatalf("EmptyCall failed: %v", err) + } + + if !proxyCalled { + t.Fatalf("Proxy not connected") + } +} + +// Tests the scenario where `grpc.NewClient` is used with the default DNS +// resolver for the target URI and a proxy is configured. The test verifies +// that the client resolves proxy URI, connects to the proxy server, sends the +// unresolved target URI in the HTTP CONNECT request, and successfully +// establishes a connection to the backend server. +func (s) TestNewClientWithProxy(t *testing.T) { + backend := startBackendServer(t) + unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address)) + proxyCalled := false + reqCheck := func(req *http.Request) { + proxyCalled = true + host, _, err := net.SplitHostPort(req.URL.Host) + if err != nil { + t.Error(err) + } + if got, want := host, "localhost"; got != want { + t.Errorf(" Unexpected request host: %s , want = %s ", got, want) + } + } + pServer := proxyserver.New(t, reqCheck, false) + // Use "localhost:" to verify the proxy address is handled + // correctly by the delegating resolver and connects to the proxy server + // correctly even when unresolved. + pAddr := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, pServer.Addr)) + + // Overwrite the function in the test and restore them in defer. + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == unresolvedTargetURI { + return &url.URL{ + Scheme: "https", + Host: pAddr, + }, nil + } + t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI) + return nil, nil + } + orighpfe := delegatingresolver.HTTPSProxyFromEnvironment + delegatingresolver.HTTPSProxyFromEnvironment = hpfe + defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.NewClient(unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err) + } + defer conn.Close() + + // Send an empty RPC to the backend through the proxy. + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Fatalf("EmptyCall failed: %v", err) + } + if !proxyCalled { + t.Fatalf("Proxy not connected") + } +} + +// Tests the scenario where grpc.NewClient is used with a custom target URI +// scheme and a proxy is configured. The test verifies that the client +// successfully connects to the proxy server, resolves the proxy URI correctly, +// includes the resolved target URI in the HTTP CONNECT request, and +// establishes a connection to the backend server. +func (s) TestNewClientWithProxyAndCustomResolver(t *testing.T) { + backend := startBackendServer(t) + unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address)) + proxyCalled := false + reqCheck := func(req *http.Request) { + proxyCalled = true + host, _, err := net.SplitHostPort(req.URL.Host) + if err != nil { + t.Error(err) + } + if got, want := isIPAddr(host), true; got != want { + t.Errorf("isIPAddr(%q) = %t, want = %t", host, got, want) + } + } + pServer := proxyserver.New(t, reqCheck, false) + + // Overwrite the function in the test and restore them in defer. + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == unresolvedTargetURI { + return &url.URL{ + Scheme: "https", + Host: pServer.Addr, + }, nil + } + t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI) + return nil, nil + } + orighpfe := delegatingresolver.HTTPSProxyFromEnvironment + delegatingresolver.HTTPSProxyFromEnvironment = hpfe + defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }() + + // Create and update a custom resolver for target URI. + targetResolver := manual.NewBuilderWithScheme("test") + resolver.Register(targetResolver) + targetResolver.InitialState(resolver.State{Endpoints: []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: backend.Address}}}}}) + + // Dial to the proxy server. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.NewClient(targetResolver.Scheme()+":///"+unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%s) failed: %v", targetResolver.Scheme()+":///"+unresolvedTargetURI, err) + } + defer conn.Close() + + // Send an empty RPC to the backend through the proxy. + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Fatalf("EmptyCall() failed: %v", err) + } + + if !proxyCalled { + t.Fatalf("Proxy not connected") + } +} + +// Tests the scenario where grpc.NewClient is used with the default "dns" +// resolver and the dial option grpc.WithLocalDNSResolution() is set, +// enabling target resolution on the client. The test verifies that target +// resolution happens on the client by sending resolved target URI in HTTP +// CONNECT request, the proxy URI is resolved correctly, and the connection is +// successfully established with the backend server through the proxy. +func (s) TestNewClientWithProxyAndTargetResolutionEnabled(t *testing.T) { + backend := startBackendServer(t) + unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address)) + proxyCalled := false + reqCheck := func(req *http.Request) { + proxyCalled = true + host, _, err := net.SplitHostPort(req.URL.Host) + if err != nil { + t.Error(err) + } + if got, want := isIPAddr(host), true; got != want { + t.Errorf("isIPAddr(%q) = %t, want = %t", host, got, want) + } + } + pServer := proxyserver.New(t, reqCheck, false) + + // Overwrite the function in the test and restore them in defer. + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == unresolvedTargetURI { + return &url.URL{ + Scheme: "https", + Host: pServer.Addr, + }, nil + } + t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI) + return nil, nil + } + orighpfe := delegatingresolver.HTTPSProxyFromEnvironment + delegatingresolver.HTTPSProxyFromEnvironment = hpfe + defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.NewClient(unresolvedTargetURI, grpc.WithLocalDNSResolution(), grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err) + } + defer conn.Close() + + // Send an empty RPC to the backend through the proxy. + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Fatalf("EmptyCall failed: %v", err) + } + + if !proxyCalled { + t.Fatalf("Proxy not connected") + } +} + +// Tests the scenario where grpc.NewClient is used with grpc.WithNoProxy() set, +// explicitly disabling proxy usage. The test verifies that the client does not +// dial the proxy but directly connects to the backend server. It also checks +// that the proxy resolution function is not called and that the proxy server +// never receives a connection request. +func (s) TestNewClientWithNoProxy(t *testing.T) { + backend := startBackendServer(t) + unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address)) + reqCheck := func(_ *http.Request) { t.Error("proxy server should not have received a Connect request") } + pServer := proxyserver.New(t, reqCheck, false) + + // Overwrite the function in the test and restore them in defer. + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == unresolvedTargetURI { + return &url.URL{ + Scheme: "https", + Host: pServer.Addr, + }, nil + } + t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI) + return nil, nil + } + orighpfe := delegatingresolver.HTTPSProxyFromEnvironment + delegatingresolver.HTTPSProxyFromEnvironment = hpfe + defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }() + + dopts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithNoProxy(), // Disable proxy. + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.NewClient(unresolvedTargetURI, dopts...) + if err != nil { + t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err) + } + defer conn.Close() + + // Create a test service client and make an RPC call. + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Fatalf("EmptyCall() failed: %v", err) + } +} + +// Tests the scenario where grpc.NewClient is used with grpc.WithContextDialer() +// set. The test verifies that the client bypasses proxy dialing and uses the +// custom dialer instead. It ensures that the proxy server is never dialed, the +// proxy resolution function is not triggered, and the custom dialer is invoked +// as expected. +func (s) TestNewClientWithContextDialer(t *testing.T) { + backend := startBackendServer(t) + unresolvedTargetURI := fmt.Sprintf("localhost:%d", testutils.ParsePort(t, backend.Address)) + reqCheck := func(_ *http.Request) { t.Error("proxy server should not have received a Connect request") } + pServer := proxyserver.New(t, reqCheck, false) + + // Overwrite the function in the test and restore them in defer. + hpfe := func(req *http.Request) (*url.URL, error) { + if req.URL.Host == unresolvedTargetURI { + return &url.URL{ + Scheme: "https", + Host: pServer.Addr, + }, nil + } + t.Errorf("Unexpected request host to proxy: %s want %s", req.URL.Host, unresolvedTargetURI) + return nil, nil + } + orighpfe := delegatingresolver.HTTPSProxyFromEnvironment + delegatingresolver.HTTPSProxyFromEnvironment = hpfe + defer func() { delegatingresolver.HTTPSProxyFromEnvironment = orighpfe }() + + // Create a custom dialer that directly dials the backend. + customDialer := func(_ context.Context, unresolvedTargetURI string) (net.Conn, error) { + return net.Dial("tcp", unresolvedTargetURI) + } + + dopts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(customDialer), // Use a custom dialer. + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.NewClient(unresolvedTargetURI, dopts...) + if err != nil { + t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err) + } + defer conn.Close() + + client := testgrpc.NewTestServiceClient(conn) + if _, err := client.EmptyCall(ctx, &testgrpc.Empty{}); err != nil { + t.Fatalf("EmptyCall() failed: %v", err) + } +} + +// Tests the scenario where grpc.NewClient is used with the default DNS resolver +// for targetURI and a proxy. The test verifies that the client connects to the +// proxy server, sends the unresolved target URI in the HTTP CONNECT request, +// and successfully connects to the backend. Additionally, it checks that the +// correct user information is included in the Proxy-Authorization header of +// the CONNECT request. The test also ensures that target resolution does not +// happen on the client. +func (s) TestBasicAuthInNewClientWithProxy(t *testing.T) { + unresolvedTargetURI := "example.test" + const ( + user = "notAUser" + password = "notAPassword" + ) + proxyCalled := false + reqCheck := func(req *http.Request) { + proxyCalled = true + if got, want := req.URL.Host, "example.test"; got != want { + t.Errorf(" Unexpected request host: %s , want = %s ", got, want) + } + wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password)) + if got := req.Header.Get("Proxy-Authorization"); got != wantProxyAuthStr { + gotDecoded, err := base64.StdEncoding.DecodeString(got) + if err != nil { + t.Errorf("failed to decode Proxy-Authorization header: %v", err) + } + wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr) + t.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded) + } + } + pServer := proxyserver.New(t, reqCheck, false) + + t.Setenv("HTTPS_PROXY", user+":"+password+"@"+pServer.Addr) + + // Use the httpproxy package functions instead of `http.ProxyFromEnvironment` + // because the latter reads proxy-related environment variables only once at + // initialization. This behavior causes issues when running test multiple + // times, as changes to environment variables during tests would be ignored. + // By using `httpproxy.FromEnvironment()`, we ensure proxy settings are read dynamically. + origHTTPSProxyFromEnvironment := delegatingresolver.HTTPSProxyFromEnvironment + delegatingresolver.HTTPSProxyFromEnvironment = func(req *http.Request) (*url.URL, error) { + return httpproxy.FromEnvironment().ProxyFunc()(req.URL) + } + defer func() { + delegatingresolver.HTTPSProxyFromEnvironment = origHTTPSProxyFromEnvironment + }() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, err := grpc.NewClient(unresolvedTargetURI, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("grpc.NewClient(%s) failed: %v", unresolvedTargetURI, err) + } + defer conn.Close() + + // Send an empty RPC to the backend through the proxy. + client := testgrpc.NewTestServiceClient(conn) + client.EmptyCall(ctx, &testgrpc.Empty{}) + + if !proxyCalled { + t.Fatalf("Proxy not connected") + } +} diff --git a/internal/transport/proxy_test.go b/internal/transport/proxy_test.go index 41f0918d1c90..442d2cb12e6d 100644 --- a/internal/transport/proxy_test.go +++ b/internal/transport/proxy_test.go @@ -22,126 +22,39 @@ package transport import ( - "bufio" - "bytes" "context" - "encoding/base64" - "fmt" - "io" "net" "net/http" - "net/url" + "net/netip" "testing" "time" -) -const ( - envTestAddr = "1.2.3.4:8080" - envProxyAddr = "2.3.4.5:7687" + "google.golang.org/grpc/internal/proxyattributes" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/testutils/proxyserver" + "google.golang.org/grpc/resolver" ) -// overwriteAndRestore overwrite function httpProxyFromEnvironment and -// returns a function to restore the default values. -func overwrite(hpfe func(req *http.Request) (*url.URL, error)) func() { - backHPFE := httpProxyFromEnvironment - httpProxyFromEnvironment = hpfe - return func() { - httpProxyFromEnvironment = backHPFE - } -} - -type proxyServer struct { - t *testing.T - lis net.Listener - in net.Conn - out net.Conn - - requestCheck func(*http.Request) error -} - -func (p *proxyServer) run(waitForServerHello bool) { - in, err := p.lis.Accept() - if err != nil { - return - } - p.in = in - - req, err := http.ReadRequest(bufio.NewReader(in)) - if err != nil { - p.t.Errorf("failed to read CONNECT req: %v", err) - return - } - if err := p.requestCheck(req); err != nil { - resp := http.Response{StatusCode: http.StatusMethodNotAllowed} - resp.Write(p.in) - p.in.Close() - p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err) - return - } - - out, err := net.Dial("tcp", req.URL.Host) +func (s) TestHTTPConnectWithServerHello(t *testing.T) { + serverMessage := []byte("server-hello") + blis, err := testutils.LocalTCPListener() if err != nil { - p.t.Errorf("failed to dial to server: %v", err) - return + t.Fatalf("failed to listen: %v", err) } - out.SetDeadline(time.Now().Add(defaultTestTimeout)) - resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"} - var buf bytes.Buffer - resp.Write(&buf) - if waitForServerHello { - // Batch the first message from the server with the http connect - // response. This is done to test the cases in which the grpc client has - // the response to the connect request and proxied packets from the - // destination server when it reads the transport. - b := make([]byte, 50) - bytesRead, err := out.Read(b) + reqCheck := func(req *http.Request) { + if req.Method != http.MethodConnect { + t.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) + } + host, _, err := net.SplitHostPort(req.URL.Host) if err != nil { - p.t.Errorf("Got error while reading server hello: %v", err) - in.Close() - out.Close() - return + t.Error(err) + } + _, err = netip.ParseAddr(host) + if err != nil { + t.Error(err) } - buf.Write(b[0:bytesRead]) - } - p.in.Write(buf.Bytes()) - p.out = out - go io.Copy(p.in, p.out) - go io.Copy(p.out, p.in) -} - -func (p *proxyServer) stop() { - p.lis.Close() - if p.in != nil { - p.in.Close() - } - if p.out != nil { - p.out.Close() - } -} - -type testArgs struct { - proxyURLModify func(*url.URL) *url.URL - proxyReqCheck func(*http.Request) error - serverMessage []byte -} - -func testHTTPConnect(t *testing.T, args testArgs) { - plis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("failed to listen: %v", err) - } - p := &proxyServer{ - t: t, - lis: plis, - requestCheck: args.proxyReqCheck, - } - go p.run(len(args.serverMessage) > 0) - defer p.stop() - - blis, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("failed to listen: %v", err) } + pServer := proxyserver.New(t, reqCheck, true) msg := []byte{4, 3, 5, 2} recvBuf := make([]byte, len(msg)) @@ -153,21 +66,15 @@ func testHTTPConnect(t *testing.T, args testArgs) { return } defer in.Close() - in.Write(args.serverMessage) + in.Write(serverMessage) in.Read(recvBuf) done <- nil }() - // Overwrite the function in the test and restore them in defer. - hpfe := func(*http.Request) (*url.URL, error) { - return args.proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil - } - defer overwrite(hpfe)() - // Dial to proxy server. ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - c, err := proxyDial(ctx, blis.Addr().String(), "test") + c, err := proxyDial(ctx, resolver.Address{Addr: pServer.Addr}, "test", proxyattributes.Options{ConnectAddr: blis.Addr().String()}) if err != nil { t.Fatalf("HTTP connect Dial failed: %v", err) } @@ -185,94 +92,14 @@ func testHTTPConnect(t *testing.T, args testArgs) { t.Fatalf("Received msg: %v, want %v", recvBuf, msg) } - if len(args.serverMessage) > 0 { - gotServerMessage := make([]byte, len(args.serverMessage)) + if len(serverMessage) > 0 { + gotServerMessage := make([]byte, len(serverMessage)) if _, err := c.Read(gotServerMessage); err != nil { t.Errorf("Got error while reading message from server: %v", err) return } - if string(gotServerMessage) != string(args.serverMessage) { - t.Errorf("Message from server: %v, want %v", gotServerMessage, args.serverMessage) - } - } -} - -func (s) TestHTTPConnect(t *testing.T) { - args := testArgs{ - proxyURLModify: func(in *url.URL) *url.URL { - return in - }, - proxyReqCheck: func(req *http.Request) error { - if req.Method != http.MethodConnect { - return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) - } - return nil - }, - } - testHTTPConnect(t, args) -} - -func (s) TestHTTPConnectWithServerHello(t *testing.T) { - args := testArgs{ - proxyURLModify: func(in *url.URL) *url.URL { - return in - }, - proxyReqCheck: func(req *http.Request) error { - if req.Method != http.MethodConnect { - return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) - } - return nil - }, - serverMessage: []byte("server-hello"), - } - testHTTPConnect(t, args) -} - -func (s) TestHTTPConnectBasicAuth(t *testing.T) { - const ( - user = "notAUser" - password = "notAPassword" - ) - args := testArgs{ - proxyURLModify: func(in *url.URL) *url.URL { - in.User = url.UserPassword(user, password) - return in - }, - proxyReqCheck: func(req *http.Request) error { - if req.Method != http.MethodConnect { - return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) - } - wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password)) - if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr { - gotDecoded, _ := base64.StdEncoding.DecodeString(got) - wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr) - return fmt.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded) - } - return nil - }, - } - testHTTPConnect(t, args) -} - -func (s) TestMapAddressEnv(t *testing.T) { - // Overwrite the function in the test and restore them in defer. - hpfe := func(req *http.Request) (*url.URL, error) { - if req.URL.Host == envTestAddr { - return &url.URL{ - Scheme: "https", - Host: envProxyAddr, - }, nil + if string(gotServerMessage) != string(serverMessage) { + t.Errorf("Message from server: %v, want %v", gotServerMessage, serverMessage) } - return nil, nil - } - defer overwrite(hpfe)() - - // envTestAddr should be handled by ProxyFromEnvironment. - got, err := mapAddress(envTestAddr) - if err != nil { - t.Error(err) - } - if got.Host != envProxyAddr { - t.Errorf("want %v, got %v", envProxyAddr, got) } } diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 2859b87755f0..af4a4aeab145 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -502,8 +502,6 @@ type ConnectOptions struct { ChannelzParent *channelz.SubChannel // MaxHeaderListSize sets the max (uncompressed) size of header list that is prepared to be received. MaxHeaderListSize *uint32 - // UseProxy specifies if a proxy should be used. - UseProxy bool // The mem.BufferPool to use when reading/writing to the wire. BufferPool mem.BufferPool } diff --git a/resolver_wrapper.go b/resolver_wrapper.go index 23bb3fb25824..9f5e12fc03c9 100644 --- a/resolver_wrapper.go +++ b/resolver_wrapper.go @@ -26,6 +26,7 @@ import ( "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/internal/resolver/delegatingresolver" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" ) @@ -78,7 +79,16 @@ func (ccr *ccResolverWrapper) start() error { Authority: ccr.cc.authority, } var err error - ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts) + // The delegating resolver is used unless: + // - A custom dialer is provided via WithContextDialer dialoption or + // - Proxy usage is disabled through WithNoProxy dialoption. + // In these cases, the resolver is built based on the scheme of target, + // using the appropriate resolver builder. + if ccr.cc.dopts.copts.Dialer != nil || !ccr.cc.dopts.useProxy { + ccr.resolver, err = ccr.cc.resolverBuilder.Build(ccr.cc.parsedTarget, ccr, opts) + } else { + ccr.resolver, err = delegatingresolver.New(ccr.cc.parsedTarget, ccr, opts, ccr.cc.resolverBuilder, ccr.cc.dopts.enableLocalDNSResolution) + } errCh <- err }) return <-errCh