diff --git a/xds/internal/xdsclient/authority.go b/xds/internal/xdsclient/authority.go index ba0b080c92bc..6ad61dae4ae4 100644 --- a/xds/internal/xdsclient/authority.go +++ b/xds/internal/xdsclient/authority.go @@ -448,10 +448,6 @@ func (a *authority) close() { a.resourcesMu.Lock() a.closed = true a.resourcesMu.Unlock() - - for _, cleanup := range a.serverCfg.Cleanups { - cleanup() - } } func (a *authority) watchResource(rType xdsresource.Type, resourceName string, watcher xdsresource.ResourceWatcher) func() { diff --git a/xds/internal/xdsclient/bootstrap/bootstrap.go b/xds/internal/xdsclient/bootstrap/bootstrap.go index 31a7a69f93cf..0736a06d73a4 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap.go @@ -167,7 +167,7 @@ type ServerConfig struct { IgnoreResourceDeletion bool // Cleanups are called when the xDS client for this server is closed. Allows - // cleaning up resources created specifically for the xDS client. + // cleaning up resources created specifically for this ServerConfig. Cleanups []func() } diff --git a/xds/internal/xdsclient/bootstrap/bootstrap_test.go b/xds/internal/xdsclient/bootstrap/bootstrap_test.go index 7975c66667b9..a1138e2363d5 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap_test.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap_test.go @@ -1008,49 +1008,39 @@ func TestServerConfigMarshalAndUnmarshal(t *testing.T) { } func TestDefaultBundles(t *testing.T) { - if c := bootstrap.GetCredentials("google_default"); c == nil { - t.Errorf(`bootstrap.GetCredentials("google_default") credential is nil, want non-nil`) - } - - if c := bootstrap.GetCredentials("insecure"); c == nil { - t.Errorf(`bootstrap.GetCredentials("insecure") credential is nil, want non-nil`) - } + tests := []string{"google_default", "insecure", "tls"} - if c := bootstrap.GetCredentials("tls"); c == nil { - t.Errorf(`bootstrap.GetCredentials("tls") credential is nil, want non-nil`) + for _, typename := range tests { + t.Run(typename, func(t *testing.T) { + if c := bootstrap.GetCredentials(typename); c == nil { + t.Errorf(`bootstrap.GetCredentials(%s) credential is nil, want non-nil`, typename) + } + }) } } func TestCredsBuilders(t *testing.T) { - b := &googleDefaultCredsBuilder{} - if _, stop, err := b.Build(nil); err != nil { - t.Errorf("googleDefaultCredsBuilder.Build failed: %v", err) - } else { - stop() - } - if got, want := b.Name(), "google_default"; got != want { - t.Errorf("googleDefaultCredsBuilder.Name = %v, want %v", got, want) - } - - i := &insecureCredsBuilder{} - if _, stop, err := i.Build(nil); err != nil { - t.Errorf("insecureCredsBuilder.Build failed: %v", err) - } else { - stop() + tests := []struct { + typename string + builder bootstrap.Credentials + }{ + {"google_default", &googleDefaultCredsBuilder{}}, + {"insecure", &insecureCredsBuilder{}}, + {"tls", &tlsCredsBuilder{}}, } - if got, want := i.Name(), "insecure"; got != want { - t.Errorf("insecureCredsBuilder.Name = %v, want %v", got, want) - } + for _, test := range tests { + t.Run(test.typename, func(t *testing.T) { + if got, want := test.builder.Name(), test.typename; got != want { + t.Errorf("%T.Name = %v, want %v", test.builder, got, want) + } - tcb := &tlsCredsBuilder{} - if _, stop, err := tcb.Build(nil); err != nil { - t.Errorf("tlsCredsBuilder.Build failed: %v", err) - } else { - stop() - } - if got, want := tcb.Name(), "tls"; got != want { - t.Errorf("tlsCredsBuilder.Name = %v, want %v", got, want) + _, stop, err := test.builder.Build(nil) + if err != nil { + t.Fatalf("%T.Build failed: %v", test.builder, err) + } + stop() + }) } } @@ -1061,9 +1051,10 @@ func TestTlsCredsBuilder(t *testing.T) { t.Fatalf("tls.Build() failed with error %s when expected to succeed", err) } stop() + if _, stop, err := tls.Build(json.RawMessage(`{"ca_certificate_file":"/ca_certificates.pem","refresh_interval": "asdf"}`)); err == nil { t.Errorf("tls.Build() succeeded with an invalid refresh interval, when expected to fail") stop() } - // more tests for config validity are defined in tlscreds subpackage. + // package internal/xdsclient/tlscreds has tests for config validity. } diff --git a/xds/internal/xdsclient/clientimpl.go b/xds/internal/xdsclient/clientimpl.go index 2c05ea66f5f9..1088b60301cb 100644 --- a/xds/internal/xdsclient/clientimpl.go +++ b/xds/internal/xdsclient/clientimpl.go @@ -85,5 +85,17 @@ func (c *clientImpl) close() { c.authorityMu.Unlock() c.serializerClose() + for _, f := range c.config.XDSServer.Cleanups { + f() + } + for _, a := range c.config.Authorities { + if a.XDSServer == nil { + // The server for this authority is the top-level one, cleaned up above. + continue + } + for _, f := range a.XDSServer.Cleanups { + f() + } + } c.logger.Infof("Shutdown") } diff --git a/xds/internal/xdsclient/singleton_test.go b/xds/internal/xdsclient/singleton_test.go index 1875ea118d09..bb3f318f1989 100644 --- a/xds/internal/xdsclient/singleton_test.go +++ b/xds/internal/xdsclient/singleton_test.go @@ -20,6 +20,7 @@ package xdsclient import ( "context" + "encoding/json" "testing" "github.com/google/uuid" @@ -36,6 +37,7 @@ func (s) TestClientNewSingleton(t *testing.T) { cleanup, err := bootstrap.CreateFile(bootstrap.Options{ NodeID: nodeID, ServerURI: "non-existent-server-address", + CertificateProviders: map[string]json.RawMessage{}, }) if err != nil { t.Fatal(err) diff --git a/xds/internal/xdsclient/tlscreds/bundle.go b/xds/internal/xdsclient/tlscreds/bundle.go index c4e977c9d76d..02da3dbf3496 100644 --- a/xds/internal/xdsclient/tlscreds/bundle.go +++ b/xds/internal/xdsclient/tlscreds/bundle.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/tls/certprovider" "google.golang.org/grpc/credentials/tls/certprovider/pemfile" + "google.golang.org/grpc/internal/grpcsync" ) // bundle is an implementation of credentials.Bundle which implements mTLS @@ -41,7 +42,9 @@ type bundle struct { // NewBundle returns a credentials.Bundle which implements mTLS Credentials in xDS // Bootstrap File. It delegates certificate loading to a file_watcher provider -// if either client certificates or server root CA is specified. +// if either client certificates or server root CA is specified. The second +// return value is a close func that should be called when the caller no longer +// needs this bundle. // See gRFC A65: github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md func NewBundle(jd json.RawMessage) (credentials.Bundle, func(), error) { cfg := &struct { @@ -78,7 +81,7 @@ func NewBundle(jd json.RawMessage) (credentials.Bundle, func(), error) { } return &bundle{ transportCredentials: &reloadingCreds{provider: provider}, - }, func() { provider.Close() }, nil + }, grpcsync.OnceFunc(func() { provider.Close() }), nil } func (t *bundle) TransportCredentials() credentials.TransportCredentials { @@ -97,15 +100,6 @@ func (t *bundle) NewWithMode(string) (credentials.Bundle, error) { return nil, fmt.Errorf("xDS TLS credentials only support one mode") } -// Close releases the underlying provider. Note that credentials.Bundle are -// not closeable, so users of this type must use a type assertion to call Close. -func (t *bundle) Close() { - cred, ok := t.transportCredentials.(*reloadingCreds) - if ok { - cred.provider.Close() - } -} - // reloadingCreds is a credentials.TransportCredentials for client // side mTLS that reloads the server root CA certificate and the client // certificates from the provider on every client handshake. This is necessary diff --git a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go index 02eedf78dee3..bda7319d83ce 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go @@ -106,11 +106,11 @@ func (s) TestValidTlsBuilder(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { msg := json.RawMessage(test.jd) - if _, stop, err := tlscreds.NewBundle(msg); err != nil { - t.Errorf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err) - } else { - stop() + _, stop, err := tlscreds.NewBundle(msg) + if err != nil { + t.Fatalf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err) } + stop() }) } } @@ -133,11 +133,12 @@ func (s) TestInvalidTlsBuilder(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { msg := json.RawMessage(test.jd) - if _, stop, err := tlscreds.NewBundle(msg); err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { - t.Errorf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix) - if err == nil { + _, stop, err := tlscreds.NewBundle(msg) + if err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { + if stop != nil { stop() } + t.Fatalf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix) } }) } diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go index 8bc3f55b4c13..ad50508aeb94 100644 --- a/xds/internal/xdsclient/tlscreds/bundle_test.go +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -46,7 +46,7 @@ func Test(t *testing.T) { type failingProvider struct{} -func (f failingProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) { +func (f failingProvider) KeyMaterial(context.Context) (*certprovider.KeyMaterial, error) { return nil, errors.New("test error") }