diff --git a/clientconn.go b/clientconn.go index 2359f94b8a45..77f2d2cceaaa 100644 --- a/clientconn.go +++ b/clientconn.go @@ -152,6 +152,16 @@ func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error) for _, opt := range opts { opt.apply(&cc.dopts) } + + // Determine the resolver to use. + if err := cc.initParsedTargetAndResolverBuilder(); err != nil { + return nil, err + } + + for _, opt := range globalPerTargetDialOptions { + opt.DialOptionForTarget(cc.parsedTarget.URL).apply(&cc.dopts) + } + chainUnaryClientInterceptors(cc) chainStreamClientInterceptors(cc) @@ -168,25 +178,16 @@ func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error) } cc.mkp = cc.dopts.copts.KeepaliveParams - // Register ClientConn with channelz. - cc.channelzRegistration(target) - - // TODO: Ideally it should be impossible to error from this function after - // channelz registration. This will require removing some channelz logs - // from the following functions that can error. Errors can be returned to - // the user, and successful logs can be emitted here, after the checks have - // passed and channelz is subsequently registered. - - // Determine the resolver to use. - if err := cc.parseTargetAndFindResolver(); err != nil { - channelz.RemoveEntry(cc.channelz.ID) - return nil, err - } - if err = cc.determineAuthority(); err != nil { - channelz.RemoveEntry(cc.channelz.ID) + if err = cc.initAuthority(); err != nil { return nil, err } + // Register ClientConn with channelz. Note that this is only done after + // channel creation cannot fail. + cc.channelzRegistration(target) + channelz.Infof(logger, cc.channelz, "parsed dial target is: %#v", cc.parsedTarget) + channelz.Infof(logger, cc.channelz, "Channel authority set to %q", cc.authority) + cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelz) cc.pickerWrapper = newPickerWrapper(cc.dopts.copts.StatsHandlers) @@ -587,11 +588,11 @@ type ClientConn struct { // The following are initialized at dial time, and are read-only after that. target string // User's dial target. - parsedTarget resolver.Target // See parseTargetAndFindResolver(). - authority string // See determineAuthority(). + parsedTarget resolver.Target // See initParsedTargetAndResolverBuilder(). + authority string // See initAuthority(). dopts dialOptions // Default and user specified dial options. channelz *channelz.Channel // Channelz object. - resolverBuilder resolver.Builder // See parseTargetAndFindResolver(). + resolverBuilder resolver.Builder // See initParsedTargetAndResolverBuilder(). idlenessMgr *idle.Manager // The following provide their own synchronization, and therefore don't @@ -1673,22 +1674,19 @@ func (cc *ClientConn) connectionError() error { return cc.lastConnectionError } -// parseTargetAndFindResolver parses the user's dial target and stores the -// parsed target in `cc.parsedTarget`. +// initParsedTargetAndResolverBuilder parses the user's dial target and stores +// the parsed target in `cc.parsedTarget`. // // The resolver to use is determined based on the scheme in the parsed target // and the same is stored in `cc.resolverBuilder`. // // Doesn't grab cc.mu as this method is expected to be called only at Dial time. -func (cc *ClientConn) parseTargetAndFindResolver() error { - channelz.Infof(logger, cc.channelz, "original dial target is: %q", cc.target) +func (cc *ClientConn) initParsedTargetAndResolverBuilder() error { + logger.Infof("original dial target is: %q", cc.target) var rb resolver.Builder parsedTarget, err := parseTarget(cc.target) - if err != nil { - channelz.Infof(logger, cc.channelz, "dial target %q parse failed: %v", cc.target, err) - } else { - channelz.Infof(logger, cc.channelz, "parsed dial target is: %#v", parsedTarget) + if err == nil { rb = cc.getResolver(parsedTarget.URL.Scheme) if rb != nil { cc.parsedTarget = parsedTarget @@ -1707,15 +1705,12 @@ func (cc *ClientConn) parseTargetAndFindResolver() error { defScheme = resolver.GetDefaultScheme() } - channelz.Infof(logger, cc.channelz, "fallback to scheme %q", defScheme) canonicalTarget := defScheme + ":///" + cc.target parsedTarget, err = parseTarget(canonicalTarget) if err != nil { - channelz.Infof(logger, cc.channelz, "dial target %q parse failed: %v", canonicalTarget, err) return err } - channelz.Infof(logger, cc.channelz, "parsed dial target is: %+v", parsedTarget) rb = cc.getResolver(parsedTarget.URL.Scheme) if rb == nil { return fmt.Errorf("could not get resolver for default scheme: %q", parsedTarget.URL.Scheme) @@ -1805,7 +1800,7 @@ func encodeAuthority(authority string) string { // credentials do not match the authority configured through the dial option. // // Doesn't grab cc.mu as this method is expected to be called only at Dial time. -func (cc *ClientConn) determineAuthority() error { +func (cc *ClientConn) initAuthority() error { dopts := cc.dopts // Historically, we had two options for users to specify the serverName or // authority for a channel. One was through the transport credentials @@ -1838,6 +1833,5 @@ func (cc *ClientConn) determineAuthority() error { } else { cc.authority = encodeAuthority(endpoint) } - channelz.Infof(logger, cc.channelz, "Channel authority set to %q", cc.authority) return nil } diff --git a/default_dial_option_server_option_test.go b/default_dial_option_server_option_test.go index c4e7143c449f..a9544d6156c8 100644 --- a/default_dial_option_server_option_test.go +++ b/default_dial_option_server_option_test.go @@ -20,6 +20,7 @@ package grpc import ( "fmt" + "net/url" "strings" "testing" @@ -40,6 +41,7 @@ func (s) TestAddGlobalDialOptions(t *testing.T) { // Set and check the DialOptions opts := []DialOption{WithTransportCredentials(insecure.NewCredentials()), WithTransportCredentials(insecure.NewCredentials()), WithTransportCredentials(insecure.NewCredentials())} internal.AddGlobalDialOptions.(func(opt ...DialOption))(opts...) + defer internal.ClearGlobalDialOptions() for i, opt := range opts { if globalDialOptions[i] != opt { t.Fatalf("Unexpected global dial option at index %d: %v != %v", i, globalDialOptions[i], opt) @@ -64,6 +66,7 @@ func (s) TestAddGlobalDialOptions(t *testing.T) { func (s) TestDisableGlobalOptions(t *testing.T) { // Set transport credentials as a global option. internal.AddGlobalDialOptions.(func(opt ...DialOption))(WithTransportCredentials(insecure.NewCredentials())) + defer internal.ClearGlobalDialOptions() // Dial with the disable global options dial option. This dial should fail // due to the global dial options with credentials not being picked up due // to global options being disabled. @@ -71,7 +74,34 @@ func (s) TestDisableGlobalOptions(t *testing.T) { if _, err := Dial("fake", internal.DisableGlobalDialOptions.(func() DialOption)()); !strings.Contains(fmt.Sprint(err), noTSecStr) { t.Fatalf("Dialing received unexpected error: %v, want error containing \"%v\"", err, noTSecStr) } - internal.ClearGlobalDialOptions() +} + +type testPerTargetDialOption struct{} + +func (do *testPerTargetDialOption) DialOptionForTarget(parsedTarget url.URL) DialOption { + if parsedTarget.Scheme == "passthrough" { + return WithTransportCredentials(insecure.NewCredentials()) // credentials provided, should pass NewClient. + } + return EmptyDialOption{} // no credentials, should fail NewClient +} + +// TestGlobalPerTargetDialOption configures a global per target dial option that +// produces transport credentials for channels using "passthrough" scheme. +// Channels that use the passthrough scheme should be successfully created due +// to picking up transport credentials, whereas other channels should fail at +// creation due to not having transport credentials. +func (s) TestGlobalPerTargetDialOption(t *testing.T) { + internal.AddGlobalPerTargetDialOptions.(func(opt any))(&testPerTargetDialOption{}) + defer internal.ClearGlobalPerTargetDialOptions() + noTSecStr := "no transport security set" + if _, err := NewClient("dns:///fake"); !strings.Contains(fmt.Sprint(err), noTSecStr) { + t.Fatalf("Dialing received unexpected error: %v, want error containing \"%v\"", err, noTSecStr) + } + cc, err := NewClient("passthrough:///nice") + if err != nil { + t.Fatalf("Dialing with insecure credentials failed: %v", err) + } + cc.Close() } func (s) TestAddGlobalServerOptions(t *testing.T) { @@ -79,6 +109,7 @@ func (s) TestAddGlobalServerOptions(t *testing.T) { // Set and check the ServerOptions opts := []ServerOption{Creds(insecure.NewCredentials()), MaxRecvMsgSize(maxRecvSize)} internal.AddGlobalServerOptions.(func(opt ...ServerOption))(opts...) + defer internal.ClearGlobalServerOptions() for i, opt := range opts { if globalServerOptions[i] != opt { t.Fatalf("Unexpected global server option at index %d: %v != %v", i, globalServerOptions[i], opt) diff --git a/dialoptions.go b/dialoptions.go index 00273702b694..1400c0e7898d 100644 --- a/dialoptions.go +++ b/dialoptions.go @@ -21,6 +21,7 @@ package grpc import ( "context" "net" + "net/url" "time" "google.golang.org/grpc/backoff" @@ -43,6 +44,14 @@ func init() { internal.ClearGlobalDialOptions = func() { globalDialOptions = nil } + internal.AddGlobalPerTargetDialOptions = func(opt any) { + if ptdo, ok := opt.(perTargetDialOption); ok { + globalPerTargetDialOptions = append(globalPerTargetDialOptions, ptdo) + } + } + internal.ClearGlobalPerTargetDialOptions = func() { + globalPerTargetDialOptions = nil + } internal.WithBinaryLogger = withBinaryLogger internal.JoinDialOptions = newJoinDialOption internal.DisableGlobalDialOptions = newDisableGlobalDialOptions @@ -89,6 +98,19 @@ type DialOption interface { var globalDialOptions []DialOption +// perTargetDialOption takes a parsed target and returns a dial option to apply. +// +// This gets called after NewClient() parses the target, and allows per target +// configuration set through a returned DialOption. The DialOption will not take +// effect if specifies a resolver builder, as that Dial Option is factored in +// while parsing target. +type perTargetDialOption interface { + // DialOption returns a Dial Option to apply. + DialOptionForTarget(parsedTarget url.URL) DialOption +} + +var globalPerTargetDialOptions []perTargetDialOption + // EmptyDialOption does not alter the dial configuration. It can be embedded in // another structure to build custom dial options. // diff --git a/internal/internal.go b/internal/internal.go index 48d24bdb4e69..146e5febf1f6 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -106,6 +106,14 @@ var ( // This is used in the 1.0 release of gcp/observability, and thus must not be // deleted or changed. ClearGlobalDialOptions func() + + // AddGlobalPerTargetDialOptions adds a PerTargetDialOption that will be + // configured for newly created ClientConns. + AddGlobalPerTargetDialOptions any // func (opt any) + // ClearGlobalPerTargetDialOptions clears the slice of global late apply + // dial options. + ClearGlobalPerTargetDialOptions func() + // JoinDialOptions combines the dial options passed as arguments into a // single dial option. JoinDialOptions any // func(...grpc.DialOption) grpc.DialOption @@ -126,7 +134,8 @@ var ( // deleted or changed. BinaryLogger any // func(binarylog.Logger) grpc.ServerOption - // SubscribeToConnectivityStateChanges adds a grpcsync.Subscriber to a provided grpc.ClientConn + // SubscribeToConnectivityStateChanges adds a grpcsync.Subscriber to a + // provided grpc.ClientConn. SubscribeToConnectivityStateChanges any // func(*grpc.ClientConn, grpcsync.Subscriber) // NewXDSResolverWithConfigForTesting creates a new xds resolver builder using @@ -195,14 +204,17 @@ var ( // resource name. TriggerXDSResourceNameNotFoundClient any // func(string, string) error - // FromOutgoingContextRaw returns the un-merged, intermediary contents of metadata.rawMD. + // FromOutgoingContextRaw returns the un-merged, intermediary contents of + // metadata.rawMD. FromOutgoingContextRaw any // func(context.Context) (metadata.MD, [][]string, bool) - // UserSetDefaultScheme is set to true if the user has overridden the default resolver scheme. + // UserSetDefaultScheme is set to true if the user has overridden the + // default resolver scheme. UserSetDefaultScheme bool = false ) -// HealthChecker defines the signature of the client-side LB channel health checking function. +// HealthChecker defines the signature of the client-side LB channel health +// checking function. // // The implementation is expected to create a health checking RPC stream by // calling newStream(), watch for the health status of serviceName, and report