Skip to content

Commit

Permalink
[FAB-17540] Fix for race read/write tlsconfig (#1050)
Browse files Browse the repository at this point in the history
- Encapsulate tls.Config in a struct with a lock to prevent data races
when accesing tls.Config
- Fix TestCreds assertion on creds.Clone() was invalid because of
tls.Config's mutex

Signed-off-by: Tiffany Harris <tiffany.harris@ibm.com>
  • Loading branch information
stephyee authored Apr 10, 2020
1 parent 856f215 commit b03b3d9
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 245 deletions.
3 changes: 2 additions & 1 deletion core/comm/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ func TestNewConnection(t *testing.T) {
assert.NotNil(t, conn)
} else {
t.Log(errors.WithStack(err))
assert.Contains(t, err.Error(), test.errorMsg)
assert.Error(t, err)
assert.Regexp(t, test.errorMsg, err)
}
})
}
Expand Down
56 changes: 49 additions & 7 deletions core/comm/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ package comm
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"sync"

"github.com/hyperledger/fabric/common/flogging"
"google.golang.org/grpc/credentials"
Expand All @@ -33,26 +35,62 @@ var (
// NewServerTransportCredentials returns a new initialized
// grpc/credentials.TransportCredentials
func NewServerTransportCredentials(
serverConfig *tls.Config,
serverConfig *TLSConfig,
logger *flogging.FabricLogger) credentials.TransportCredentials {

// NOTE: unlike the default grpc/credentials implementation, we do not
// clone the tls.Config which allows us to update it dynamically
serverConfig.NextProtos = alpnProtoStr
serverConfig.config.NextProtos = alpnProtoStr
// override TLS version and ensure it is 1.2
serverConfig.MinVersion = tls.VersionTLS12
serverConfig.MaxVersion = tls.VersionTLS12
serverConfig.config.MinVersion = tls.VersionTLS12
serverConfig.config.MaxVersion = tls.VersionTLS12
return &serverCreds{
serverConfig: serverConfig,
logger: logger}
}

// serverCreds is an implementation of grpc/credentials.TransportCredentials.
type serverCreds struct {
serverConfig *tls.Config
serverConfig *TLSConfig
logger *flogging.FabricLogger
}

type TLSConfig struct {
config *tls.Config
lock sync.RWMutex
}

func NewTLSConfig(config *tls.Config) *TLSConfig {
return &TLSConfig{
config: config,
}
}

func (t *TLSConfig) Config() tls.Config {
t.lock.RLock()
defer t.lock.RUnlock()

if t.config != nil {
return *t.config.Clone()
}

return tls.Config{}
}

func (t *TLSConfig) AddClientRootCA(cert *x509.Certificate) {
t.lock.Lock()
defer t.lock.Unlock()

t.config.ClientCAs.AddCert(cert)
}

func (t *TLSConfig) SetClientCAs(certPool *x509.CertPool) {
t.lock.Lock()
defer t.lock.Unlock()

t.config.ClientCAs = certPool
}

// ClientHandShake is not implemented for `serverCreds`.
func (sc *serverCreds) ClientHandshake(context.Context,
string, net.Conn) (net.Conn, credentials.AuthInfo, error) {
Expand All @@ -61,7 +99,9 @@ func (sc *serverCreds) ClientHandshake(context.Context,

// ServerHandshake does the authentication handshake for servers.
func (sc *serverCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
conn := tls.Server(rawConn, sc.serverConfig)
serverConfig := sc.serverConfig.Config()

conn := tls.Server(rawConn, &serverConfig)
if err := conn.Handshake(); err != nil {
if sc.logger != nil {
sc.logger.With("remote address",
Expand All @@ -82,7 +122,9 @@ func (sc *serverCreds) Info() credentials.ProtocolInfo {

// Clone makes a copy of this TransportCredentials.
func (sc *serverCreds) Clone() credentials.TransportCredentials {
creds := NewServerTransportCredentials(sc.serverConfig, sc.logger)
config := sc.serverConfig.Config()
serverConfig := NewTLSConfig(&config)
creds := NewServerTransportCredentials(serverConfig, sc.logger)
return creds
}

Expand Down
82 changes: 79 additions & 3 deletions core/comm/creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ func TestCreds(t *testing.T) {
Certificates: []tls.Certificate{cert},
}

config := comm.NewTLSConfig(tlsConfig)

logger, recorder := floggingtest.NewTestLogger(t)

creds := comm.NewServerTransportCredentials(tlsConfig, logger)
creds := comm.NewServerTransportCredentials(config, logger)
_, _, err = creds.ClientHandshake(nil, "", nil)
assert.EqualError(t, err, comm.ClientHandshakeNotImplError.Error())
err = creds.OverrideServerName("")
assert.EqualError(t, err, comm.OverrrideHostnameNotSupportedError.Error())
clone := creds.Clone()
assert.Equal(t, creds, clone)
assert.Equal(t, "1.2", creds.Info().SecurityVersion)
assert.Equal(t, "tls", creds.Info().SecurityProtocol)

Expand Down Expand Up @@ -95,3 +95,79 @@ func TestCreds(t *testing.T) {
assert.Contains(t, err.Error(), "protocol version not supported")
assert.Contains(t, recorder.Messages()[0], "TLS handshake failed with error")
}

func TestNewTLSConfig(t *testing.T) {
t.Parallel()
tlsConfig := &tls.Config{}

config := comm.NewTLSConfig(tlsConfig)

assert.NotEmpty(t, config, "TLSConfig is not empty")
}

func TestConfig(t *testing.T) {
t.Parallel()
config := comm.NewTLSConfig(&tls.Config{
ServerName: "bueno",
})

configCopy := config.Config()

certPool := x509.NewCertPool()
config.SetClientCAs(certPool)

assert.NotEqual(t, config.Config(), &configCopy, "TLSConfig should have new certs")
}

func TestAddRootCA(t *testing.T) {
t.Parallel()

caPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem"))
if err != nil {
t.Fatalf("failed to read root certificate: %v", err)
}

cert := &x509.Certificate{
EmailAddresses: []string{"test@foobar.com"},
}

expectedCertPool := x509.NewCertPool()
ok := expectedCertPool.AppendCertsFromPEM(caPEM)
if !ok {
t.Fatalf("failed to create expected certPool")
}

expectedCertPool.AddCert(cert)

certPool := x509.NewCertPool()
ok = certPool.AppendCertsFromPEM(caPEM)
if !ok {
t.Fatalf("failed to create certPool")
}

tlsConfig := &tls.Config{
ClientCAs: certPool,
}
config := comm.NewTLSConfig(tlsConfig)

assert.Equal(t, config.Config().ClientCAs, certPool)

config.AddClientRootCA(cert)

assert.Equal(t, config.Config().ClientCAs, expectedCertPool, "The CertPools should be equal")
}

func TestSetClientCAs(t *testing.T) {
t.Parallel()
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{},
}
config := comm.NewTLSConfig(tlsConfig)

assert.Empty(t, config.Config().ClientCAs, "No CertPool should be defined")

certPool := x509.NewCertPool()
config.SetClientCAs(certPool)

assert.NotNil(t, config.Config().ClientCAs, "The CertPools' should not be the same")
}
42 changes: 14 additions & 28 deletions core/comm/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type GRPCServer struct {
// the tlsConfig.ClientCAs indexed by subject
clientRootCAs map[string]*x509.Certificate
// TLS configuration used by the grpc server
tlsConfig *tls.Config
tls *TLSConfig
}

// NewGRPCServer creates a new implementation of a GRPCServer given a
Expand Down Expand Up @@ -92,28 +92,28 @@ func NewGRPCServerFromListener(listener net.Listener, serverConfig ServerConfig)
return &cert, nil
}
//base server certificate
grpcServer.tlsConfig = &tls.Config{
grpcServer.tls = NewTLSConfig(&tls.Config{
VerifyPeerCertificate: secureConfig.VerifyCertificate,
GetCertificate: getCert,
SessionTicketsDisabled: true,
CipherSuites: secureConfig.CipherSuites,
}
})

if serverConfig.SecOpts.TimeShift > 0 {
timeShift := serverConfig.SecOpts.TimeShift
grpcServer.tlsConfig.Time = func() time.Time {
grpcServer.tls.config.Time = func() time.Time {
return time.Now().Add((-1) * timeShift)
}
}
grpcServer.tlsConfig.ClientAuth = tls.RequestClientCert
grpcServer.tls.config.ClientAuth = tls.RequestClientCert
//check if client authentication is required
if secureConfig.RequireClientCert {
//require TLS client auth
grpcServer.tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
grpcServer.tls.config.ClientAuth = tls.RequireAndVerifyClientCert
//if we have client root CAs, create a certPool
if len(secureConfig.ClientRootCAs) > 0 {
grpcServer.clientRootCAs = make(map[string]*x509.Certificate)
grpcServer.tlsConfig.ClientCAs = x509.NewCertPool()
grpcServer.tls.config.ClientCAs = x509.NewCertPool()
for _, clientRootCA := range secureConfig.ClientRootCAs {
err = grpcServer.appendClientRootCA(clientRootCA)
if err != nil {
Expand All @@ -124,7 +124,7 @@ func NewGRPCServerFromListener(listener net.Listener, serverConfig ServerConfig)
}

// create credentials and add to server options
creds := NewServerTransportCredentials(grpcServer.tlsConfig, serverConfig.Logger)
creds := NewServerTransportCredentials(grpcServer.tls, serverConfig.Logger)
serverOpts = append(serverOpts, grpc.Creds(creds))
} else {
return nil, errors.New("serverConfig.SecOpts must contain both Key and Certificate when UseTLS is true")
Expand Down Expand Up @@ -193,14 +193,14 @@ func (gServer *GRPCServer) ServerCertificate() tls.Certificate {
// TLSEnabled is a flag indicating whether or not TLS is enabled for the
// GRPCServer instance
func (gServer *GRPCServer) TLSEnabled() bool {
return gServer.tlsConfig != nil
return gServer.tls != nil
}

// MutualTLSRequired is a flag indicating whether or not client certificates
// are required for this GRPCServer instance
func (gServer *GRPCServer) MutualTLSRequired() bool {
return gServer.tlsConfig != nil &&
gServer.tlsConfig.ClientAuth ==
return gServer.TLSEnabled() &&
gServer.tls.Config().ClientAuth ==
tls.RequireAndVerifyClientCert
}

Expand All @@ -214,20 +214,6 @@ func (gServer *GRPCServer) Stop() {
gServer.server.Stop()
}

// AppendClientRootCAs appends PEM-encoded X509 certificate authorities to
// the list of authorities used to verify client certificates
func (gServer *GRPCServer) AppendClientRootCAs(clientRoots [][]byte) error {
gServer.lock.Lock()
defer gServer.lock.Unlock()
for _, clientRoot := range clientRoots {
err := gServer.appendClientRootCA(clientRoot)
if err != nil {
return err
}
}
return nil
}

// internal function to add a PEM-encoded clientRootCA
func (gServer *GRPCServer) appendClientRootCA(clientRoot []byte) error {

Expand All @@ -244,7 +230,7 @@ func (gServer *GRPCServer) appendClientRootCA(clientRoot []byte) error {

for i, cert := range certs {
//first add to the ClientCAs
gServer.tlsConfig.ClientCAs.AddCert(cert)
gServer.tls.AddClientRootCA(cert)
//add it to our clientRootCAs map using subject as key
gServer.clientRootCAs[subjects[i]] = cert
}
Expand All @@ -271,7 +257,7 @@ func (gServer *GRPCServer) RemoveClientRootCAs(clientRoots [][]byte) error {
}

//replace the current ClientCAs pool
gServer.tlsConfig.ClientCAs = certPool
gServer.tls.SetClientCAs(certPool)
return nil
}

Expand Down Expand Up @@ -330,6 +316,6 @@ func (gServer *GRPCServer) SetClientRootCAs(clientRoots [][]byte) error {
//replace the internal map
gServer.clientRootCAs = clientRootCAs
//replace the current ClientCAs pool
gServer.tlsConfig.ClientCAs = certPool
gServer.tls.SetClientCAs(certPool)
return nil
}
Loading

0 comments on commit b03b3d9

Please sign in to comment.