diff --git a/core/comm/server.go b/core/comm/server.go index 276523d78ce..146d6e67663 100644 --- a/core/comm/server.go +++ b/core/comm/server.go @@ -53,6 +53,9 @@ type GRPCServer interface { //TLSEnabled is a flag indicating whether or not TLS is enabled for this //GRPCServer instance TLSEnabled() bool + //MutualTLSRequired is a flag indicating whether or not client certificates + //are required for this GRPCServer instance + MutualTLSRequired() bool //AppendClientRootCAs appends PEM-encoded X509 certificate authorities to //the list of authorities used to verify client certificates AppendClientRootCAs(clientRoots [][]byte) error @@ -87,6 +90,8 @@ type grpcServerImpl struct { tlsConfig *tls.Config //Is TLS enabled? tlsEnabled bool + //Are client certifictes required + mutualTLSRequired bool } //NewGRPCServer creates a new implementation of a GRPCServer given a @@ -159,6 +164,7 @@ func newGRPCServerFromListenerWithKa(listener net.Listener, secureConfig SecureS grpcServer.tlsConfig.ClientAuth = tls.RequestClientCert //check if client authentication is required if secureConfig.RequireClientCert { + grpcServer.mutualTLSRequired = true //require TLS client auth grpcServer.tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert //if we have client root CAs, create a certPool @@ -219,6 +225,12 @@ func (gServer *grpcServerImpl) TLSEnabled() bool { return gServer.tlsEnabled } +//MutualTLSRequired is a flag indicating whether or not client certificates +//are required for this GRPCServer instance +func (gServer *grpcServerImpl) MutualTLSRequired() bool { + return gServer.mutualTLSRequired +} + //Start starts the underlying grpc.Server func (gServer *grpcServerImpl) Start() error { return gServer.server.Serve(gServer.listener) diff --git a/core/comm/server_test.go b/core/comm/server_test.go index ecdfb50144d..00422c975b1 100644 --- a/core/comm/server_test.go +++ b/core/comm/server_test.go @@ -489,8 +489,10 @@ func TestNewGRPCServer(t *testing.T) { assert.Equal(t, srv.Address(), addr.String()) assert.Equal(t, srv.Listener().Addr().String(), addr.String()) - //TlSEnabled should be false + //TLSEnabled should be false assert.Equal(t, srv.TLSEnabled(), false) + //MutualTLSRequired should be false + assert.Equal(t, srv.MutualTLSRequired(), false) //register the GRPC test server testpb.RegisterTestServiceServer(srv.Server(), &testServiceServer{}) @@ -542,8 +544,10 @@ func TestNewGRPCServerFromListener(t *testing.T) { assert.Equal(t, srv.Address(), addr.String()) assert.Equal(t, srv.Listener().Addr().String(), addr.String()) - //TlSEnabled should be false + //TLSEnabled should be false assert.Equal(t, srv.TLSEnabled(), false) + //MutualTLSRequired should be false + assert.Equal(t, srv.MutualTLSRequired(), false) //register the GRPC test server testpb.RegisterTestServiceServer(srv.Server(), &testServiceServer{}) @@ -594,8 +598,10 @@ func TestNewSecureGRPCServer(t *testing.T) { cert, _ := tls.X509KeyPair([]byte(selfSignedCertPEM), []byte(selfSignedKeyPEM)) assert.Equal(t, srv.ServerCertificate(), cert) - //TlSEnabled should be true + //TLSEnabled should be true assert.Equal(t, srv.TLSEnabled(), true) + //MutualTLSRequired should be false + assert.Equal(t, srv.MutualTLSRequired(), false) //register the GRPC test server testpb.RegisterTestServiceServer(srv.Server(), &testServiceServer{}) @@ -677,8 +683,10 @@ func TestNewSecureGRPCServerFromListener(t *testing.T) { cert, _ := tls.X509KeyPair([]byte(selfSignedCertPEM), []byte(selfSignedKeyPEM)) assert.Equal(t, srv.ServerCertificate(), cert) - //TlSEnabled should be true + //TLSEnabled should be true assert.Equal(t, srv.TLSEnabled(), true) + //MutualTLSRequired should be false + assert.Equal(t, srv.MutualTLSRequired(), false) //register the GRPC test server testpb.RegisterTestServiceServer(srv.Server(), &testServiceServer{}) @@ -894,6 +902,9 @@ func runMutualAuth(t *testing.T, servers []testServer, trustedClients, unTrusted return err } + //MutualTLSRequired should be true + assert.Equal(t, srv.MutualTLSRequired(), true) + //register the GRPC test server and start the GRPCServer testpb.RegisterTestServiceServer(srv.Server(), &testServiceServer{}) go srv.Start()