diff --git a/credentials/tls/certprovider/pemfile/builder.go b/credentials/tls/certprovider/pemfile/builder.go index 8d8e2d4a0f5a..8c15baeb59f6 100644 --- a/credentials/tls/certprovider/pemfile/builder.go +++ b/credentials/tls/certprovider/pemfile/builder.go @@ -29,7 +29,7 @@ import ( ) const ( - pluginName = "file_watcher" + PluginName = "file_watcher" defaultRefreshInterval = 10 * time.Minute ) @@ -48,13 +48,13 @@ func (p *pluginBuilder) ParseConfig(c any) (*certprovider.BuildableConfig, error if err != nil { return nil, err } - return certprovider.NewBuildableConfig(pluginName, opts.canonical(), func(certprovider.BuildOptions) certprovider.Provider { + return certprovider.NewBuildableConfig(PluginName, opts.canonical(), func(certprovider.BuildOptions) certprovider.Provider { return newProvider(opts) }), nil } func (p *pluginBuilder) Name() string { - return pluginName + return PluginName } func pluginConfigFromJSON(jd json.RawMessage) (Options, error) { diff --git a/internal/testutils/xds/e2e/setup_certs.go b/internal/testutils/xds/e2e/setup_certs.go index dea392162595..fb289487a4a9 100644 --- a/internal/testutils/xds/e2e/setup_certs.go +++ b/internal/testutils/xds/e2e/setup_certs.go @@ -98,7 +98,7 @@ func CreateClientTLSCredentials(t *testing.T) credentials.TransportCredentials { // CreateServerTLSCredentials creates server-side TLS transport credentials // using certificate and key files from testdata/x509 directory. -func CreateServerTLSCredentials(t *testing.T) credentials.TransportCredentials { +func CreateServerTLSCredentials(t *testing.T, clientAuth tls.ClientAuthType) credentials.TransportCredentials { t.Helper() cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) @@ -114,7 +114,7 @@ func CreateServerTLSCredentials(t *testing.T) credentials.TransportCredentials { t.Fatal("Failed to append certificates") } return credentials.NewTLS(&tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, + ClientAuth: clientAuth, Certificates: []tls.Certificate{cert}, ClientCAs: ca, }) diff --git a/test/xds/xds_client_certificate_providers_test.go b/test/xds/xds_client_certificate_providers_test.go index a2979ca1beae..7741dc7581b9 100644 --- a/test/xds/xds_client_certificate_providers_test.go +++ b/test/xds/xds_client_certificate_providers_test.go @@ -20,6 +20,7 @@ package xds_test import ( "context" + "crypto/tls" "fmt" "strings" "testing" @@ -226,7 +227,7 @@ func (s) TestClientSideXDS_WithValidAndInvalidSecurityConfiguration(t *testing.T // backend1 configured with TLS creds, represents cluster1 // backend2 configured with insecure creds, represents cluster2 // backend3 configured with insecure creds, represents cluster3 - creds := e2e.CreateServerTLSCredentials(t) + creds := e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert) server1 := stubserver.StartTestService(t, nil, grpc.Creds(creds)) defer server1.Stop() server2 := stubserver.StartTestService(t, nil) diff --git a/xds/bootstrap/bootstrap.go b/xds/bootstrap/bootstrap.go index fcb99bdfd967..ef55ff0c02db 100644 --- a/xds/bootstrap/bootstrap.go +++ b/xds/bootstrap/bootstrap.go @@ -37,8 +37,10 @@ var registry = make(map[string]Credentials) // Credentials interface encapsulates a credentials.Bundle builder // that can be used for communicating with the xDS Management server. type Credentials interface { - // Build returns a credential bundle associated with this credential. - Build(config json.RawMessage) (credentials.Bundle, error) + // Build returns a credential bundle associated with this credential, and + // a function to cleans up additional resources associated with this bundle + // when it is no longer needed. + Build(config json.RawMessage) (credentials.Bundle, func(), error) // Name returns the credential name associated with this credential. Name() string } diff --git a/xds/bootstrap/bootstrap_test.go b/xds/bootstrap/bootstrap_test.go index 80ae31ccd2e3..1afc3ce7075a 100644 --- a/xds/bootstrap/bootstrap_test.go +++ b/xds/bootstrap/bootstrap_test.go @@ -36,9 +36,9 @@ type testCredsBuilder struct { config json.RawMessage } -func (t *testCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, error) { +func (t *testCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, func(), error) { t.config = config - return nil, nil + return nil, nil, nil } func (t *testCredsBuilder) Name() string { @@ -53,7 +53,7 @@ func TestRegisterNew(t *testing.T) { const sampleConfig = "sample_config" rawMessage := json.RawMessage(sampleConfig) - if _, err := c.Build(rawMessage); err != nil { + if _, _, err := c.Build(rawMessage); err != nil { t.Errorf("Build(%v) error = %v, want nil", rawMessage, err) } diff --git a/xds/internal/xdsclient/bootstrap/bootstrap.go b/xds/internal/xdsclient/bootstrap/bootstrap.go index 57fcb087b28b..0736a06d73a4 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap.go @@ -39,6 +39,7 @@ import ( "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/xds/bootstrap" + "google.golang.org/grpc/xds/internal/xdsclient/tlscreds" ) const ( @@ -60,6 +61,7 @@ const ( func init() { bootstrap.RegisterCredentials(&insecureCredsBuilder{}) bootstrap.RegisterCredentials(&googleDefaultCredsBuilder{}) + bootstrap.RegisterCredentials(&tlsCredsBuilder{}) } // For overriding in unit tests. @@ -69,20 +71,32 @@ var bootstrapFileReadFunc = os.ReadFile // package `xds/bootstrap` and encapsulates an insecure credential. type insecureCredsBuilder struct{} -func (i *insecureCredsBuilder) Build(json.RawMessage) (credentials.Bundle, error) { - return insecure.NewBundle(), nil +func (i *insecureCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) { + return insecure.NewBundle(), func() {}, nil } func (i *insecureCredsBuilder) Name() string { return "insecure" } +// tlsCredsBuilder implements the `Credentials` interface defined in +// package `xds/bootstrap` and encapsulates a TLS credential. +type tlsCredsBuilder struct{} + +func (t *tlsCredsBuilder) Build(config json.RawMessage) (credentials.Bundle, func(), error) { + return tlscreds.NewBundle(config) +} + +func (t *tlsCredsBuilder) Name() string { + return "tls" +} + // googleDefaultCredsBuilder implements the `Credentials` interface defined in // package `xds/boostrap` and encapsulates a Google Default credential. type googleDefaultCredsBuilder struct{} -func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, error) { - return google.NewDefaultCredentials(), nil +func (d *googleDefaultCredsBuilder) Build(json.RawMessage) (credentials.Bundle, func(), error) { + return google.NewDefaultCredentials(), func() {}, nil } func (d *googleDefaultCredsBuilder) Name() string { @@ -151,6 +165,10 @@ type ServerConfig struct { // when a resource is deleted, nor will it remove the existing resource value // from its cache. IgnoreResourceDeletion bool + + // Cleanups are called when the xDS client for this server is closed. Allows + // cleaning up resources created specifically for this ServerConfig. + Cleanups []func() } // CredsDialOption returns the configured credentials as a grpc dial option. @@ -206,12 +224,13 @@ func (sc *ServerConfig) UnmarshalJSON(data []byte) error { if c == nil { continue } - bundle, err := c.Build(cc.Config) + bundle, cancel, err := c.Build(cc.Config) if err != nil { return fmt.Errorf("failed to build credentials bundle from bootstrap for %q: %v", cc.Type, err) } sc.Creds = ChannelCreds(cc) sc.credsDialOption = grpc.WithCredentialsBundle(bundle) + sc.Cleanups = append(sc.Cleanups, cancel) break } return nil diff --git a/xds/internal/xdsclient/bootstrap/bootstrap_test.go b/xds/internal/xdsclient/bootstrap/bootstrap_test.go index 84075743a8fe..a1138e2363d5 100644 --- a/xds/internal/xdsclient/bootstrap/bootstrap_test.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap_test.go @@ -1008,30 +1008,53 @@ func TestServerConfigMarshalAndUnmarshal(t *testing.T) { } func TestDefaultBundles(t *testing.T) { - if c := bootstrap.GetCredentials("google_default"); c == nil { - t.Errorf(`bootstrap.GetCredentials("google_default") credential is nil, want non-nil`) - } + tests := []string{"google_default", "insecure", "tls"} - if c := bootstrap.GetCredentials("insecure"); c == nil { - t.Errorf(`bootstrap.GetCredentials("insecure") credential is nil, want non-nil`) + for _, typename := range tests { + t.Run(typename, func(t *testing.T) { + if c := bootstrap.GetCredentials(typename); c == nil { + t.Errorf(`bootstrap.GetCredentials(%s) credential is nil, want non-nil`, typename) + } + }) } } func TestCredsBuilders(t *testing.T) { - b := &googleDefaultCredsBuilder{} - if _, err := b.Build(nil); err != nil { - t.Errorf("googleDefaultCredsBuilder.Build failed: %v", err) + tests := []struct { + typename string + builder bootstrap.Credentials + }{ + {"google_default", &googleDefaultCredsBuilder{}}, + {"insecure", &insecureCredsBuilder{}}, + {"tls", &tlsCredsBuilder{}}, } - if got, want := b.Name(), "google_default"; got != want { - t.Errorf("googleDefaultCredsBuilder.Name = %v, want %v", got, want) + + for _, test := range tests { + t.Run(test.typename, func(t *testing.T) { + if got, want := test.builder.Name(), test.typename; got != want { + t.Errorf("%T.Name = %v, want %v", test.builder, got, want) + } + + _, stop, err := test.builder.Build(nil) + if err != nil { + t.Fatalf("%T.Build failed: %v", test.builder, err) + } + stop() + }) } +} - i := &insecureCredsBuilder{} - if _, err := i.Build(nil); err != nil { - t.Errorf("insecureCredsBuilder.Build failed: %v", err) +func TestTlsCredsBuilder(t *testing.T) { + tls := &tlsCredsBuilder{} + _, stop, err := tls.Build(json.RawMessage(`{}`)) + if err != nil { + t.Fatalf("tls.Build() failed with error %s when expected to succeed", err) } + stop() - if got, want := i.Name(), "insecure"; got != want { - t.Errorf("insecureCredsBuilder.Name = %v, want %v", got, want) + if _, stop, err := tls.Build(json.RawMessage(`{"ca_certificate_file":"/ca_certificates.pem","refresh_interval": "asdf"}`)); err == nil { + t.Errorf("tls.Build() succeeded with an invalid refresh interval, when expected to fail") + stop() } + // package internal/xdsclient/tlscreds has tests for config validity. } diff --git a/xds/internal/xdsclient/clientimpl.go b/xds/internal/xdsclient/clientimpl.go index 2c05ea66f5f9..1088b60301cb 100644 --- a/xds/internal/xdsclient/clientimpl.go +++ b/xds/internal/xdsclient/clientimpl.go @@ -85,5 +85,17 @@ func (c *clientImpl) close() { c.authorityMu.Unlock() c.serializerClose() + for _, f := range c.config.XDSServer.Cleanups { + f() + } + for _, a := range c.config.Authorities { + if a.XDSServer == nil { + // The server for this authority is the top-level one, cleaned up above. + continue + } + for _, f := range a.XDSServer.Cleanups { + f() + } + } c.logger.Infof("Shutdown") } diff --git a/xds/internal/xdsclient/tlscreds/bundle.go b/xds/internal/xdsclient/tlscreds/bundle.go new file mode 100644 index 000000000000..02da3dbf3496 --- /dev/null +++ b/xds/internal/xdsclient/tlscreds/bundle.go @@ -0,0 +1,138 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package tlscreds implements mTLS Credentials in xDS Bootstrap File. +// See gRFC A65: github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md. +package tlscreds + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "net" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/tls/certprovider" + "google.golang.org/grpc/credentials/tls/certprovider/pemfile" + "google.golang.org/grpc/internal/grpcsync" +) + +// bundle is an implementation of credentials.Bundle which implements mTLS +// Credentials in xDS Bootstrap File. +type bundle struct { + transportCredentials credentials.TransportCredentials +} + +// NewBundle returns a credentials.Bundle which implements mTLS Credentials in xDS +// Bootstrap File. It delegates certificate loading to a file_watcher provider +// if either client certificates or server root CA is specified. The second +// return value is a close func that should be called when the caller no longer +// needs this bundle. +// See gRFC A65: github.com/grpc/proposal/blob/master/A65-xds-mtls-creds-in-bootstrap.md +func NewBundle(jd json.RawMessage) (credentials.Bundle, func(), error) { + cfg := &struct { + CertificateFile string `json:"certificate_file"` + CACertificateFile string `json:"ca_certificate_file"` + PrivateKeyFile string `json:"private_key_file"` + }{} + + if jd != nil { + if err := json.Unmarshal(jd, cfg); err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal config: %v", err) + } + } // Else the config field is absent. Treat it as an empty config. + + if cfg.CACertificateFile == "" && cfg.CertificateFile == "" && cfg.PrivateKeyFile == "" { + // We cannot use (and do not need) a file_watcher provider in this case, + // and can simply directly use the TLS transport credentials. + // Quoting A65: + // + // > The only difference between the file-watcher certificate provider + // > config and this one is that in the file-watcher certificate + // > provider, at least one of the "certificate_file" or + // > "ca_certificate_file" fields must be specified, whereas in this + // > configuration, it is acceptable to specify neither one. + return &bundle{transportCredentials: credentials.NewTLS(&tls.Config{})}, func() {}, nil + } + // Otherwise we need to use a file_watcher provider to watch the CA, + // private and public keys. + + // The pemfile plugin (file_watcher) currently ignores BuildOptions. + provider, err := certprovider.GetProvider(pemfile.PluginName, jd, certprovider.BuildOptions{}) + if err != nil { + return nil, nil, err + } + return &bundle{ + transportCredentials: &reloadingCreds{provider: provider}, + }, grpcsync.OnceFunc(func() { provider.Close() }), nil +} + +func (t *bundle) TransportCredentials() credentials.TransportCredentials { + return t.transportCredentials +} + +func (t *bundle) PerRPCCredentials() credentials.PerRPCCredentials { + // mTLS provides transport credentials only. There are no per-RPC + // credentials. + return nil +} + +func (t *bundle) NewWithMode(string) (credentials.Bundle, error) { + // This bundle has a single mode which only uses TLS transport credentials, + // so there is no legitimate case where callers would call NewWithMode. + return nil, fmt.Errorf("xDS TLS credentials only support one mode") +} + +// reloadingCreds is a credentials.TransportCredentials for client +// side mTLS that reloads the server root CA certificate and the client +// certificates from the provider on every client handshake. This is necessary +// because the standard TLS credentials do not support reloading CA +// certificates. +type reloadingCreds struct { + provider certprovider.Provider +} + +func (c *reloadingCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + km, err := c.provider.KeyMaterial(ctx) + if err != nil { + return nil, nil, err + } + config := &tls.Config{ + RootCAs: km.Roots, + Certificates: km.Certs, + } + return credentials.NewTLS(config).ClientHandshake(ctx, authority, rawConn) +} + +func (c *reloadingCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{SecurityProtocol: "tls"} +} + +func (c *reloadingCreds) Clone() credentials.TransportCredentials { + return &reloadingCreds{provider: c.provider} +} + +func (c *reloadingCreds) OverrideServerName(string) error { + return errors.New("overriding server name is not supported by xDS client TLS credentials") +} + +func (c *reloadingCreds) ServerHandshake(net.Conn) (net.Conn, credentials.AuthInfo, error) { + return nil, nil, errors.New("server handshake is not supported by xDS client TLS credentials") +} diff --git a/xds/internal/xdsclient/tlscreds/bundle_ext_test.go b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go new file mode 100644 index 000000000000..bda7319d83ce --- /dev/null +++ b/xds/internal/xdsclient/tlscreds/bundle_ext_test.go @@ -0,0 +1,253 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package tlscreds_test + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "os" + "strings" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils/xds/e2e" + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/status" + "google.golang.org/grpc/testdata" + "google.golang.org/grpc/xds/internal/xdsclient/tlscreds" +) + +const defaultTestTimeout = 5 * time.Second + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +type Closable interface { + Close() +} + +func (s) TestValidTlsBuilder(t *testing.T) { + caCert := testdata.Path("x509/server_ca_cert.pem") + clientCert := testdata.Path("x509/client1_cert.pem") + clientKey := testdata.Path("x509/client1_key.pem") + tests := []struct { + name string + jd string + }{ + { + name: "Absent configuration", + jd: `null`, + }, + { + name: "Empty configuration", + jd: `{}`, + }, + { + name: "Only CA certificate chain", + jd: fmt.Sprintf(`{"ca_certificate_file": "%s"}`, caCert), + }, + { + name: "Only private key and certificate chain", + jd: fmt.Sprintf(`{"certificate_file":"%s","private_key_file":"%s"}`, clientCert, clientKey), + }, + { + name: "CA chain, private key and certificate chain", + jd: fmt.Sprintf(`{"ca_certificate_file":"%s","certificate_file":"%s","private_key_file":"%s"}`, caCert, clientCert, clientKey), + }, + { + name: "Only refresh interval", jd: `{"refresh_interval": "1s"}`, + }, + { + name: "Refresh interval and CA certificate chain", + jd: fmt.Sprintf(`{"refresh_interval": "1s","ca_certificate_file": "%s"}`, caCert), + }, + { + name: "Refresh interval, private key and certificate chain", + jd: fmt.Sprintf(`{"refresh_interval": "1s","certificate_file":"%s","private_key_file":"%s"}`, clientCert, clientKey), + }, + { + name: "Refresh interval, CA chain, private key and certificate chain", + jd: fmt.Sprintf(`{"refresh_interval": "1s","ca_certificate_file":"%s","certificate_file":"%s","private_key_file":"%s"}`, caCert, clientCert, clientKey), + }, + { + name: "Unknown field", + jd: `{"unknown_field": "foo"}`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + msg := json.RawMessage(test.jd) + _, stop, err := tlscreds.NewBundle(msg) + if err != nil { + t.Fatalf("NewBundle(%s) returned error %s when expected to succeed", test.jd, err) + } + stop() + }) + } +} + +func (s) TestInvalidTlsBuilder(t *testing.T) { + tests := []struct { + name, jd, wantErrPrefix string + }{ + { + name: "Wrong type in json", + jd: `{"ca_certificate_file": 1}`, + wantErrPrefix: "failed to unmarshal config:"}, + { + name: "Missing private key", + jd: fmt.Sprintf(`{"certificate_file":"%s"}`, testdata.Path("x509/server_cert.pem")), + wantErrPrefix: "pemfile: private key file and identity cert file should be both specified or not specified", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + msg := json.RawMessage(test.jd) + _, stop, err := tlscreds.NewBundle(msg) + if err == nil || !strings.HasPrefix(err.Error(), test.wantErrPrefix) { + if stop != nil { + stop() + } + t.Fatalf("NewBundle(%s): got error %s, want an error with prefix %s", msg, err, test.wantErrPrefix) + } + }) + } +} + +func (s) TestCaReloading(t *testing.T) { + serverCa, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem")) + if err != nil { + t.Fatalf("Failed to read test CA cert: %s", err) + } + + // Write CA certs to a temporary file so that we can modify it later. + caPath := t.TempDir() + "/ca.pem" + if err = os.WriteFile(caPath, serverCa, 0644); err != nil { + t.Fatalf("Failed to write test CA cert: %v", err) + } + cfg := fmt.Sprintf(`{ + "ca_certificate_file": "%s", + "refresh_interval": ".01s" + }`, caPath) + tlsBundle, stop, err := tlscreds.NewBundle([]byte(cfg)) + if err != nil { + t.Fatalf("Failed to create TLS bundle: %v", err) + } + defer stop() + + serverCredentials := grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.NoClientCert)) + server := stubserver.StartTestService(t, nil, serverCredentials) + + conn, err := grpc.Dial( + server.Address, + grpc.WithCredentialsBundle(tlsBundle), + grpc.WithAuthority("x.test.example.com"), + ) + if err != nil { + t.Fatalf("Error dialing: %v", err) + } + defer conn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + client := testgrpc.NewTestServiceClient(conn) + if _, err = client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Errorf("Error calling EmptyCall: %v", err) + } + // close the server and create a new one to force client to do a new + // handshake. + server.Stop() + + invalidCa, err := os.ReadFile(testdata.Path("ca.pem")) + if err != nil { + t.Fatalf("Failed to read test CA cert: %v", err) + } + // unload root cert + err = os.WriteFile(caPath, invalidCa, 0644) + if err != nil { + t.Fatalf("Failed to write test CA cert: %v", err) + } + + for ; ctx.Err() == nil; <-time.After(10 * time.Millisecond) { + ss := stubserver.StubServer{ + Address: server.Address, + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, + } + server = stubserver.StartTestService(t, &ss, serverCredentials) + + // Client handshake should eventually fail because the client CA was + // reloaded, and thus the server cert is signed by an unknown CA. + t.Log(server) + _, err = client.EmptyCall(ctx, &testpb.Empty{}) + const wantErr = "certificate signed by unknown authority" + if status.Code(err) == codes.Unavailable && strings.Contains(err.Error(), wantErr) { + // Certs have reloaded. + server.Stop() + break + } + t.Logf("EmptyCall() got err: %s, want code: %s, want err: %s", err, codes.Unavailable, wantErr) + server.Stop() + } + if ctx.Err() != nil { + t.Errorf("Timed out waiting for CA certs reloading") + } +} + +func (s) TestMTLS(t *testing.T) { + s := stubserver.StartTestService(t, nil, grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert))) + defer s.Stop() + + cfg := fmt.Sprintf(`{ + "ca_certificate_file": "%s", + "certificate_file": "%s", + "private_key_file": "%s" + }`, + testdata.Path("x509/server_ca_cert.pem"), + testdata.Path("x509/client1_cert.pem"), + testdata.Path("x509/client1_key.pem")) + tlsBundle, stop, err := tlscreds.NewBundle([]byte(cfg)) + if err != nil { + t.Fatalf("Failed to create TLS bundle: %v", err) + } + defer stop() + conn, err := grpc.Dial(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com")) + if err != nil { + t.Fatalf("Error dialing: %v", err) + } + defer conn.Close() + client := testgrpc.NewTestServiceClient(conn) + if _, err = client.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + t.Errorf("EmptyCall(): got error %v when expected to succeed", err) + } +} diff --git a/xds/internal/xdsclient/tlscreds/bundle_test.go b/xds/internal/xdsclient/tlscreds/bundle_test.go new file mode 100644 index 000000000000..ad50508aeb94 --- /dev/null +++ b/xds/internal/xdsclient/tlscreds/bundle_test.go @@ -0,0 +1,92 @@ +/* + * + * Copyright 2023 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package tlscreds + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "strings" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/tls/certprovider" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils/xds/e2e" + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/testdata" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +type failingProvider struct{} + +func (f failingProvider) KeyMaterial(context.Context) (*certprovider.KeyMaterial, error) { + return nil, errors.New("test error") +} + +func (f failingProvider) Close() {} + +func (s) TestFailingProvider(t *testing.T) { + s := stubserver.StartTestService(t, nil, grpc.Creds(e2e.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert))) + defer s.Stop() + + cfg := fmt.Sprintf(`{ + "ca_certificate_file": "%s", + "certificate_file": "%s", + "private_key_file": "%s" + }`, + testdata.Path("x509/server_ca_cert.pem"), + testdata.Path("x509/client1_cert.pem"), + testdata.Path("x509/client1_key.pem")) + tlsBundle, stop, err := NewBundle([]byte(cfg)) + if err != nil { + t.Fatalf("Failed to create TLS bundle: %v", err) + } + stop() + + // Force a provider that returns an error, and make sure the client fails + // the handshake. + creds, ok := tlsBundle.TransportCredentials().(*reloadingCreds) + if !ok { + t.Fatalf("Got %T, expected reloadingCreds", tlsBundle.TransportCredentials()) + } + creds.provider = &failingProvider{} + + conn, err := grpc.Dial(s.Address, grpc.WithCredentialsBundle(tlsBundle), grpc.WithAuthority("x.test.example.com")) + if err != nil { + t.Fatalf("Error dialing: %v", err) + } + defer conn.Close() + + client := testgrpc.NewTestServiceClient(conn) + _, err = client.EmptyCall(context.Background(), &testpb.Empty{}) + if wantErr := "test error"; err == nil || !strings.Contains(err.Error(), wantErr) { + t.Errorf("EmptyCall() got err: %s, want err to contain: %s", err, wantErr) + } +}