Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FAB-17540] Fix for race read/write tlsconfig #1050

Merged
merged 1 commit into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion core/comm/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ func TestNewConnection(t *testing.T) {
serverTLS: &tls.Config{
Certificates: []tls.Certificate{testCerts.serverCert},
ClientAuth: tls.RequireAndVerifyClientCert,
MaxVersion: tls.VersionTLS12, // https://github.com/golang/go/issues/33368
},
success: false,
errorMsg: "tls: bad certificate",
Expand Down Expand Up @@ -358,7 +359,8 @@ func TestNewConnection(t *testing.T) {
assert.NotNil(t, conn)
} else {
t.Log(errors.WithStack(err))
assert.Contains(t, err.Error(), test.errorMsg)
assert.Error(t, err)
assert.Regexp(t, test.errorMsg, err)
}
})
}
Expand Down
56 changes: 49 additions & 7 deletions core/comm/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ package comm
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"sync"

"github.com/hyperledger/fabric/common/flogging"
"google.golang.org/grpc/credentials"
Expand All @@ -33,26 +35,62 @@ var (
// NewServerTransportCredentials returns a new initialized
// grpc/credentials.TransportCredentials
func NewServerTransportCredentials(
serverConfig *tls.Config,
serverConfig *TLSConfig,
logger *flogging.FabricLogger) credentials.TransportCredentials {

// NOTE: unlike the default grpc/credentials implementation, we do not
// clone the tls.Config which allows us to update it dynamically
serverConfig.NextProtos = alpnProtoStr
serverConfig.config.NextProtos = alpnProtoStr
// override TLS version and ensure it is 1.2
serverConfig.MinVersion = tls.VersionTLS12
serverConfig.MaxVersion = tls.VersionTLS12
serverConfig.config.MinVersion = tls.VersionTLS12
serverConfig.config.MaxVersion = tls.VersionTLS12
return &serverCreds{
serverConfig: serverConfig,
logger: logger}
}

// serverCreds is an implementation of grpc/credentials.TransportCredentials.
type serverCreds struct {
serverConfig *tls.Config
serverConfig *TLSConfig
logger *flogging.FabricLogger
}

type TLSConfig struct {
config *tls.Config
lock sync.RWMutex
}

func NewTLSConfig(config *tls.Config) *TLSConfig {
return &TLSConfig{
config: config,
}
}

func (t *TLSConfig) Config() tls.Config {
t.lock.RLock()
defer t.lock.RUnlock()

if t.config != nil {
return *t.config.Clone()
}

return tls.Config{}
}

func (t *TLSConfig) AddClientRootCA(cert *x509.Certificate) {
t.lock.Lock()
defer t.lock.Unlock()

t.config.ClientCAs.AddCert(cert)
}

func (t *TLSConfig) SetClientCAs(certPool *x509.CertPool) {
t.lock.Lock()
defer t.lock.Unlock()

t.config.ClientCAs = certPool
}

// ClientHandShake is not implemented for `serverCreds`.
func (sc *serverCreds) ClientHandshake(context.Context,
string, net.Conn) (net.Conn, credentials.AuthInfo, error) {
Expand All @@ -61,7 +99,9 @@ func (sc *serverCreds) ClientHandshake(context.Context,

// ServerHandshake does the authentication handshake for servers.
func (sc *serverCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
conn := tls.Server(rawConn, sc.serverConfig)
serverConfig := sc.serverConfig.Config()

conn := tls.Server(rawConn, &serverConfig)
if err := conn.Handshake(); err != nil {
if sc.logger != nil {
sc.logger.With("remote address",
Expand All @@ -82,7 +122,9 @@ func (sc *serverCreds) Info() credentials.ProtocolInfo {

// Clone makes a copy of this TransportCredentials.
func (sc *serverCreds) Clone() credentials.TransportCredentials {
creds := NewServerTransportCredentials(sc.serverConfig, sc.logger)
config := sc.serverConfig.Config()
serverConfig := NewTLSConfig(&config)
creds := NewServerTransportCredentials(serverConfig, sc.logger)
return creds
}

Expand Down
82 changes: 79 additions & 3 deletions core/comm/creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ func TestCreds(t *testing.T) {
Certificates: []tls.Certificate{cert},
}

config := comm.NewTLSConfig(tlsConfig)

logger, recorder := floggingtest.NewTestLogger(t)

creds := comm.NewServerTransportCredentials(tlsConfig, logger)
creds := comm.NewServerTransportCredentials(config, logger)
_, _, err = creds.ClientHandshake(nil, "", nil)
assert.EqualError(t, err, comm.ClientHandshakeNotImplError.Error())
err = creds.OverrideServerName("")
assert.EqualError(t, err, comm.OverrrideHostnameNotSupportedError.Error())
clone := creds.Clone()
assert.Equal(t, creds, clone)
assert.Equal(t, "1.2", creds.Info().SecurityVersion)
assert.Equal(t, "tls", creds.Info().SecurityProtocol)

Expand Down Expand Up @@ -95,3 +95,79 @@ func TestCreds(t *testing.T) {
assert.Contains(t, err.Error(), "protocol version not supported")
assert.Contains(t, recorder.Messages()[0], "TLS handshake failed with error")
}

func TestNewTLSConfig(t *testing.T) {
t.Parallel()
tlsConfig := &tls.Config{}

config := comm.NewTLSConfig(tlsConfig)

assert.NotEmpty(t, config, "TLSConfig is not empty")
}

func TestConfig(t *testing.T) {
t.Parallel()
config := comm.NewTLSConfig(&tls.Config{
ServerName: "bueno",
})

configCopy := config.Config()

certPool := x509.NewCertPool()
config.SetClientCAs(certPool)

assert.NotEqual(t, config.Config(), &configCopy, "TLSConfig should have new certs")
}

func TestAddRootCA(t *testing.T) {
t.Parallel()

caPEM, err := ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem"))
if err != nil {
t.Fatalf("failed to read root certificate: %v", err)
}

cert := &x509.Certificate{
EmailAddresses: []string{"test@foobar.com"},
}

expectedCertPool := x509.NewCertPool()
ok := expectedCertPool.AppendCertsFromPEM(caPEM)
if !ok {
t.Fatalf("failed to create expected certPool")
}

expectedCertPool.AddCert(cert)

certPool := x509.NewCertPool()
ok = certPool.AppendCertsFromPEM(caPEM)
if !ok {
t.Fatalf("failed to create certPool")
}

tlsConfig := &tls.Config{
ClientCAs: certPool,
}
config := comm.NewTLSConfig(tlsConfig)

assert.Equal(t, config.Config().ClientCAs, certPool)

config.AddClientRootCA(cert)

assert.Equal(t, config.Config().ClientCAs, expectedCertPool, "The CertPools should be equal")
}

func TestSetClientCAs(t *testing.T) {
t.Parallel()
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{},
}
config := comm.NewTLSConfig(tlsConfig)

assert.Empty(t, config.Config().ClientCAs, "No CertPool should be defined")

certPool := x509.NewCertPool()
config.SetClientCAs(certPool)

assert.NotNil(t, config.Config().ClientCAs, "The CertPools' should not be the same")
}
42 changes: 14 additions & 28 deletions core/comm/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type GRPCServer struct {
// the tlsConfig.ClientCAs indexed by subject
clientRootCAs map[string]*x509.Certificate
// TLS configuration used by the grpc server
tlsConfig *tls.Config
tls *TLSConfig
}

// NewGRPCServer creates a new implementation of a GRPCServer given a
Expand Down Expand Up @@ -92,28 +92,28 @@ func NewGRPCServerFromListener(listener net.Listener, serverConfig ServerConfig)
return &cert, nil
}
//base server certificate
grpcServer.tlsConfig = &tls.Config{
grpcServer.tls = NewTLSConfig(&tls.Config{
VerifyPeerCertificate: secureConfig.VerifyCertificate,
GetCertificate: getCert,
SessionTicketsDisabled: true,
CipherSuites: secureConfig.CipherSuites,
}
})

if serverConfig.SecOpts.TimeShift > 0 {
timeShift := serverConfig.SecOpts.TimeShift
grpcServer.tlsConfig.Time = func() time.Time {
grpcServer.tls.config.Time = func() time.Time {
return time.Now().Add((-1) * timeShift)
}
}
grpcServer.tlsConfig.ClientAuth = tls.RequestClientCert
grpcServer.tls.config.ClientAuth = tls.RequestClientCert
//check if client authentication is required
if secureConfig.RequireClientCert {
//require TLS client auth
grpcServer.tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
grpcServer.tls.config.ClientAuth = tls.RequireAndVerifyClientCert
//if we have client root CAs, create a certPool
if len(secureConfig.ClientRootCAs) > 0 {
grpcServer.clientRootCAs = make(map[string]*x509.Certificate)
grpcServer.tlsConfig.ClientCAs = x509.NewCertPool()
grpcServer.tls.config.ClientCAs = x509.NewCertPool()
for _, clientRootCA := range secureConfig.ClientRootCAs {
err = grpcServer.appendClientRootCA(clientRootCA)
if err != nil {
Expand All @@ -124,7 +124,7 @@ func NewGRPCServerFromListener(listener net.Listener, serverConfig ServerConfig)
}

// create credentials and add to server options
creds := NewServerTransportCredentials(grpcServer.tlsConfig, serverConfig.Logger)
creds := NewServerTransportCredentials(grpcServer.tls, serverConfig.Logger)
serverOpts = append(serverOpts, grpc.Creds(creds))
} else {
return nil, errors.New("serverConfig.SecOpts must contain both Key and Certificate when UseTLS is true")
Expand Down Expand Up @@ -193,14 +193,14 @@ func (gServer *GRPCServer) ServerCertificate() tls.Certificate {
// TLSEnabled is a flag indicating whether or not TLS is enabled for the
// GRPCServer instance
func (gServer *GRPCServer) TLSEnabled() bool {
return gServer.tlsConfig != nil
return gServer.tls != nil
}

// MutualTLSRequired is a flag indicating whether or not client certificates
// are required for this GRPCServer instance
func (gServer *GRPCServer) MutualTLSRequired() bool {
return gServer.tlsConfig != nil &&
gServer.tlsConfig.ClientAuth ==
return gServer.TLSEnabled() &&
gServer.tls.Config().ClientAuth ==
tls.RequireAndVerifyClientCert
}

Expand All @@ -214,20 +214,6 @@ func (gServer *GRPCServer) Stop() {
gServer.server.Stop()
}

// AppendClientRootCAs appends PEM-encoded X509 certificate authorities to
// the list of authorities used to verify client certificates
func (gServer *GRPCServer) AppendClientRootCAs(clientRoots [][]byte) error {
gServer.lock.Lock()
defer gServer.lock.Unlock()
for _, clientRoot := range clientRoots {
err := gServer.appendClientRootCA(clientRoot)
if err != nil {
return err
}
}
return nil
}

// internal function to add a PEM-encoded clientRootCA
func (gServer *GRPCServer) appendClientRootCA(clientRoot []byte) error {

Expand All @@ -244,7 +230,7 @@ func (gServer *GRPCServer) appendClientRootCA(clientRoot []byte) error {

for i, cert := range certs {
//first add to the ClientCAs
gServer.tlsConfig.ClientCAs.AddCert(cert)
gServer.tls.AddClientRootCA(cert)
//add it to our clientRootCAs map using subject as key
gServer.clientRootCAs[subjects[i]] = cert
}
Expand All @@ -271,7 +257,7 @@ func (gServer *GRPCServer) RemoveClientRootCAs(clientRoots [][]byte) error {
}

//replace the current ClientCAs pool
gServer.tlsConfig.ClientCAs = certPool
gServer.tls.SetClientCAs(certPool)
return nil
}

Expand Down Expand Up @@ -330,6 +316,6 @@ func (gServer *GRPCServer) SetClientRootCAs(clientRoots [][]byte) error {
//replace the internal map
gServer.clientRootCAs = clientRootCAs
//replace the current ClientCAs pool
gServer.tlsConfig.ClientCAs = certPool
gServer.tls.SetClientCAs(certPool)
return nil
}
Loading