diff --git a/pkg/core/config/comm/tls/certpool.go b/pkg/core/config/comm/tls/certpool.go new file mode 100644 index 0000000000..d7c31dfd46 --- /dev/null +++ b/pkg/core/config/comm/tls/certpool.go @@ -0,0 +1,117 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package tls + +import ( + "crypto/x509" + "sync" + + "github.com/hyperledger/fabric-sdk-go/pkg/common/logging" +) + +var logger = logging.NewLogger("fabsdk/core") + +// CertPool is a thread safe wrapper around the x509 standard library +// cert pool implementation. +type CertPool interface { + // Get returns the cert pool, optionally adding the provided certs + Get(certs ...*x509.Certificate) (*x509.CertPool, error) +} + +// certPool is a thread safe wrapper around the x509 standard library +// cert pool implementation. +// It optionally allows loading the system trust store. +type certPool struct { + useSystemCertPool bool + certs []*x509.Certificate + certPool *x509.CertPool + certsByName map[string][]int + lock sync.RWMutex +} + +// NewCertPool new CertPool implementation +func NewCertPool(useSystemCertPool bool) CertPool { + return &certPool{ + useSystemCertPool: useSystemCertPool, + certsByName: make(map[string][]int), + certPool: x509.NewCertPool(), + } +} + +func (c *certPool) Get(certs ...*x509.Certificate) (*x509.CertPool, error) { + c.lock.RLock() + if len(certs) == 0 || c.containsCerts(certs...) { + defer c.lock.RUnlock() + return c.certPool, nil + } + c.lock.RUnlock() + + // We have a cert we have not encountered before, recreate the cert pool + certPool, err := c.loadSystemCertPool() + if err != nil { + return nil, err + } + + c.lock.Lock() + defer c.lock.Unlock() + + //add certs to SDK cert list + for _, newCert := range certs { + c.addCert(newCert) + } + //add all certs to cert pool + for _, cert := range c.certs { + certPool.AddCert(cert) + } + c.certPool = certPool + + return c.certPool, nil +} + +func (c *certPool) addCert(newCert *x509.Certificate) { + if newCert != nil && !c.containsCert(newCert) { + n := len(c.certs) + // Store cert + c.certs = append(c.certs, newCert) + // Store cert name index + name := string(newCert.RawSubject) + c.certsByName[name] = append(c.certsByName[name], n) + } +} + +func (c *certPool) containsCert(newCert *x509.Certificate) bool { + possibilities := c.certsByName[string(newCert.RawSubject)] + for _, p := range possibilities { + if c.certs[p].Equal(newCert) { + return true + } + } + + return false +} + +func (c *certPool) containsCerts(certs ...*x509.Certificate) bool { + for _, cert := range certs { + if cert != nil && !c.containsCert(cert) { + return false + } + } + return true +} + +func (c *certPool) loadSystemCertPool() (*x509.CertPool, error) { + if !c.useSystemCertPool { + return x509.NewCertPool(), nil + } + systemCertPool, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + logger.Debugf("Loaded system cert pool of size: %d", len(systemCertPool.Subjects())) + + return systemCertPool, nil +} diff --git a/pkg/core/config/comm/tls/certpool_test.go b/pkg/core/config/comm/tls/certpool_test.go new file mode 100644 index 0000000000..2798c61835 --- /dev/null +++ b/pkg/core/config/comm/tls/certpool_test.go @@ -0,0 +1,144 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package tls + +import ( + "crypto/x509" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var goodCert = &x509.Certificate{ + RawSubject: []byte("Good header"), + Raw: []byte("Good cert"), +} + +func TestTLSCAConfig(t *testing.T) { + tlsCertPool := NewCertPool(true).(*certPool) + _, err := tlsCertPool.Get(goodCert) + require.NoError(t, err) + assert.Equal(t, true, tlsCertPool.useSystemCertPool) + assert.NotNil(t, tlsCertPool.certPool) + assert.NotNil(t, tlsCertPool.certsByName) + + originalLength := len(tlsCertPool.certs) + //Try again with same cert + _, err = tlsCertPool.Get(goodCert) + assert.NoError(t, err, "TLS CA cert pool fetch failed") + assert.False(t, len(tlsCertPool.certs) > originalLength, "number of certs in cert list shouldn't accept duplicates") + + // Test with system cert pool disabled + tlsCertPool = NewCertPool(false).(*certPool) + _, err = tlsCertPool.Get(goodCert) + require.NoError(t, err) + assert.Len(t, tlsCertPool.certs, 1) + assert.Len(t, tlsCertPool.certPool.Subjects(), 1) +} + +func TestTLSCAPoolManyCerts(t *testing.T) { + size := 50 + + tlsCertPool := NewCertPool(true).(*certPool) + _, err := tlsCertPool.Get(goodCert) + require.NoError(t, err) + + pool, err := tlsCertPool.Get() + assert.NoError(t, err) + originalLen := len(pool.Subjects()) + + certs := createNCerts(size) + pool, err = tlsCertPool.Get(certs[0]) + assert.NoError(t, err) + assert.Len(t, pool.Subjects(), originalLen+1) + + pool, err = tlsCertPool.Get(certs...) + assert.NoError(t, err) + assert.Len(t, pool.Subjects(), originalLen+size) +} + +func TestConcurrent(t *testing.T) { + concurrency := 1000 + certs := createNCerts(concurrency) + + tlsCertPool := NewCertPool(false).(*certPool) + + writeDone := make(chan bool) + readDone := make(chan bool) + + for i := 0; i < concurrency; i++ { + go func(c *x509.Certificate) { + _, err := tlsCertPool.Get(c) + assert.NoError(t, err) + writeDone <- true + }(certs[i]) + go func() { + _, err := tlsCertPool.Get() + assert.NoError(t, err) + readDone <- true + }() + } + + for i := 0; i < concurrency; i++ { + select { + case b := <-writeDone: + assert.True(t, b) + case <-time.After(time.Second * 10): + t.Fatalf("Timed out waiting for write %d", i) + } + + select { + case b := <-readDone: + assert.True(t, b) + case <-time.After(time.Second * 10): + t.Fatalf("Timed out waiting for read %d", i) + } + } + + assert.Len(t, tlsCertPool.certs, concurrency) + assert.Len(t, tlsCertPool.certPool.Subjects(), concurrency) +} + +func createNCerts(n int) []*x509.Certificate { + var certs []*x509.Certificate + for i := 0; i < n; i++ { + cert := &x509.Certificate{ + RawSubject: []byte(strconv.Itoa(i)), + Raw: []byte(strconv.Itoa(i)), + } + certs = append(certs, cert) + } + return certs +} + +func BenchmarkTLSCertPool(b *testing.B) { + tlsCertPool := NewCertPool(true).(*certPool) + + for n := 0; n < b.N; n++ { + tlsCertPool.Get() + } +} + +func BenchmarkTLSCertPoolSameCert(b *testing.B) { + tlsCertPool := NewCertPool(true).(*certPool) + + for n := 0; n < b.N; n++ { + tlsCertPool.Get(goodCert) + } +} + +func BenchmarkTLSCertPoolDifferentCert(b *testing.B) { + tlsCertPool := NewCertPool(true).(*certPool) + certs := createNCerts(b.N) + + for n := 0; n < b.N; n++ { + tlsCertPool.Get(certs[n]) + } +} diff --git a/pkg/fab/endpointconfig.go b/pkg/fab/endpointconfig.go index d36091d0d3..0a8ae4b7c7 100644 --- a/pkg/fab/endpointconfig.go +++ b/pkg/fab/endpointconfig.go @@ -15,7 +15,6 @@ import ( "sort" "strconv" "strings" - "sync" "time" "github.com/hyperledger/fabric-sdk-go/pkg/common/errors/status" @@ -23,6 +22,7 @@ import ( "github.com/hyperledger/fabric-sdk-go/pkg/common/providers/core" "github.com/hyperledger/fabric-sdk-go/pkg/common/providers/fab" "github.com/hyperledger/fabric-sdk-go/pkg/common/providers/msp" + commtls "github.com/hyperledger/fabric-sdk-go/pkg/core/config/comm/tls" "github.com/hyperledger/fabric-sdk-go/pkg/core/config/cryptoutil" "github.com/hyperledger/fabric-sdk-go/pkg/core/config/endpoint" "github.com/hyperledger/fabric-sdk-go/pkg/core/config/lookup" @@ -60,8 +60,6 @@ const ( func ConfigFromBackend(coreBackend core.ConfigBackend) (fab.EndpointConfig, error) { config := &EndpointConfig{ backend: lookup.New(coreBackend), - tlsCertsByName: make(map[string][]int), - tlsCertPool: x509.NewCertPool(), peerMatchers: make(map[int]*regexp.Regexp), ordererMatchers: make(map[int]*regexp.Regexp), caMatchers: make(map[int]*regexp.Regexp), @@ -72,6 +70,7 @@ func ConfigFromBackend(coreBackend core.ConfigBackend) (fab.EndpointConfig, erro return nil, errors.WithMessage(err, "network configuration load failed") } + config.tlsCertPool = commtls.NewCertPool(config.backend.GetBool("client.tlsCerts.systemCertPool")) // preemptively add all TLS certs to cert pool as adding them at request time // is expensive certs, err := config.loadTLSCerts() @@ -94,16 +93,13 @@ func ConfigFromBackend(coreBackend core.ConfigBackend) (fab.EndpointConfig, erro // EndpointConfig represents the endpoint configuration for the client type EndpointConfig struct { backend *lookup.ConfigLookup - tlsCerts []*x509.Certificate networkConfig *fab.NetworkConfig + tlsCertPool commtls.CertPool networkConfigCached bool - tlsCertPool *x509.CertPool peerMatchers map[int]*regexp.Regexp ordererMatchers map[int]*regexp.Regexp caMatchers map[int]*regexp.Regexp channelMatchers map[int]*regexp.Regexp - tlsCertsByName map[string][]int - certPoolLock sync.RWMutex } // Timeout reads timeouts for the given timeout type, if type is not found in the config @@ -464,33 +460,7 @@ func (c *EndpointConfig) ChannelOrderers(name string) ([]fab.OrdererConfig, erro // TLSCACertPool returns the configured cert pool. If a certConfig // is provided, the certficate is added to the pool func (c *EndpointConfig) TLSCACertPool(certs ...*x509.Certificate) (*x509.CertPool, error) { - c.certPoolLock.RLock() - if len(certs) == 0 || c.containsCerts(certs...) { - defer c.certPoolLock.RUnlock() - return c.tlsCertPool, nil - } - c.certPoolLock.RUnlock() - - // We have a cert we have not encountered before, recreate the cert pool - tlsCertPool, err := c.loadSystemCertPool() - if err != nil { - return nil, err - } - - c.certPoolLock.Lock() - defer c.certPoolLock.Unlock() - - //add certs to SDK cert list - for _, newCert := range certs { - c.addCert(newCert) - } - //add all certs to cert pool - for _, cert := range c.tlsCerts { - tlsCertPool.AddCert(cert) - } - c.tlsCertPool = tlsCertPool - - return c.tlsCertPool, nil + return c.tlsCertPool.Get(certs...) } // EventServiceType returns the type of event service client to use @@ -1107,50 +1077,6 @@ func (c *EndpointConfig) loadTLSCerts() ([]*x509.Certificate, error) { return certs, nil } -func (c *EndpointConfig) addCert(newCert *x509.Certificate) { - if newCert != nil && !c.containsCert(newCert) { - n := len(c.tlsCerts) - // Store cert - c.tlsCerts = append(c.tlsCerts, newCert) - // Store cert name index - name := string(newCert.RawSubject) - c.tlsCertsByName[name] = append(c.tlsCertsByName[name], n) - } -} - -func (c *EndpointConfig) containsCert(newCert *x509.Certificate) bool { - possibilities := c.tlsCertsByName[string(newCert.RawSubject)] - for _, p := range possibilities { - if c.tlsCerts[p].Equal(newCert) { - return true - } - } - - return false -} - -func (c *EndpointConfig) containsCerts(certs ...*x509.Certificate) bool { - for _, cert := range certs { - if cert != nil && !c.containsCert(cert) { - return false - } - } - return true -} - -func (c *EndpointConfig) loadSystemCertPool() (*x509.CertPool, error) { - if !c.backend.GetBool("client.tlsCerts.systemCertPool") { - return x509.NewCertPool(), nil - } - systemCertPool, err := x509.SystemCertPool() - if err != nil { - return nil, err - } - logger.Debugf("Loaded system cert pool of size: %d", len(systemCertPool.Subjects())) - - return systemCertPool, nil -} - // Client returns the Client config func (c *EndpointConfig) client() (*msp.ClientConfig, error) { config, err := c.NetworkConfig() diff --git a/pkg/fab/endpointconfig_test.go b/pkg/fab/endpointconfig_test.go index 5df5b20ad2..7f84d49d9f 100644 --- a/pkg/fab/endpointconfig_test.go +++ b/pkg/fab/endpointconfig_test.go @@ -8,8 +8,6 @@ package fab import ( "crypto/tls" - "crypto/x509" - "strconv" "testing" "os" @@ -27,11 +25,9 @@ import ( "github.com/hyperledger/fabric-sdk-go/pkg/common/providers/core" "github.com/hyperledger/fabric-sdk-go/pkg/common/providers/fab" "github.com/hyperledger/fabric-sdk-go/pkg/core/config" - "github.com/hyperledger/fabric-sdk-go/pkg/core/config/endpoint" "github.com/hyperledger/fabric-sdk-go/pkg/core/mocks" "github.com/hyperledger/fabric-sdk-go/pkg/util/pathvar" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const ( @@ -135,88 +131,6 @@ func checkCAConfigFailsByNetworkConfig(sampleEndpointConfig *EndpointConfig, t * } } -func TestTLSCAConfig(t *testing.T) { - //Test TLSCA Cert Pool (Positive test case) - certConfig := endpoint.TLSConfig{Path: pathvar.Subst(certPath)} - - cert, err := certConfig.TLSCert() - if err != nil { - t.Fatalf("Failed to get TLS CA Cert, reason: %v", err) - } - - configBackend1, err := ConfigFromBackend(configBackend) - if err != nil { - t.Fatalf("Failed to get endpoint config, reason: %v", err) - } - - endpointConfig := configBackend1.(*EndpointConfig) - - _, err = endpointConfig.TLSCACertPool(cert) - if err != nil { - t.Fatalf("TLS CA cert pool fetch failed, reason: %v", err) - } - - originalLength := len(endpointConfig.tlsCerts) - //Try again with same cert - _, err = endpointConfig.TLSCACertPool(cert) - assert.NoError(t, err, "TLS CA cert pool fetch failed") - assert.False(t, len(endpointConfig.tlsCerts) > originalLength, "number of certs in cert list shouldn't accept duplicates") - - //Test TLSCA Cert Pool (Negative test case) - badCertConfig := endpoint.TLSConfig{Path: "some random invalid path"} - badCert, err := badCertConfig.TLSCert() - if err == nil { - t.Fatalf("TLS CA cert pool was supposed to fail") - } - - _, err = endpointConfig.TLSCACertPool(badCert) - - if err != nil { - t.Fatalf(err.Error()) - } - - keyConfig := endpoint.TLSConfig{Path: keyPath} - - key, err := keyConfig.TLSCert() - - if err == nil { - t.Fatalf("TLS CA cert pool was supposed to fail when provided with wrong cert file") - } - - _, err = endpointConfig.TLSCACertPool(key) - if err != nil { - t.Fatalf(err.Error()) - } -} - -func TestTLSCAPoolManyCerts(t *testing.T) { - size := 50 - var certs []*x509.Certificate - - configBackend, err := ConfigFromBackend(configBackend) - assert.NoError(t, err) - endpointConfig := configBackend.(*EndpointConfig) - - pool, err := endpointConfig.TLSCACertPool() - assert.NoError(t, err) - originalLen := len(pool.Subjects()) - - for i := 0; i < size; i++ { - cert := &x509.Certificate{ - RawSubject: []byte(strconv.Itoa(i)), - Raw: []byte(strconv.Itoa(i)), - } - certs = append(certs, cert) - } - pool, err = endpointConfig.TLSCACertPool(certs[0]) - assert.NoError(t, err) - assert.Len(t, pool.Subjects(), originalLen+1) - - pool, err = endpointConfig.TLSCACertPool(certs...) - assert.NoError(t, err) - assert.Len(t, pool.Subjects(), originalLen+size) -} - func TestTimeouts(t *testing.T) { customBackend := getCustomBackend() customBackend.KeyValueMap["client.peer.timeout.connection"] = "12s" @@ -1285,46 +1199,6 @@ func tamperPeerChannelConfig(backend *mocks.MockConfigBackend) { (channelsMap.(map[string]interface{}))[orgChannelID] = orgChannel } -func BenchmarkTLSCertPool(b *testing.B) { - customBackend := getCustomBackend() - customBackend.KeyValueMap["client.tlsCerts.systemCertPool"] = "true" - endpointConfig, err := ConfigFromBackend(customBackend) - require.NoError(b, err) - - for n := 0; n < b.N; n++ { - endpointConfig.TLSCACertPool() - } -} - -func BenchmarkTLSCertPoolSameCert(b *testing.B) { - customBackend := getCustomBackend() - customBackend.KeyValueMap["client.tlsCerts.systemCertPool"] = "true" - endpointConfig, err := ConfigFromBackend(customBackend) - require.NoError(b, err) - certConfig := endpoint.TLSConfig{Path: pathvar.Subst(certPath)} - cert, err := certConfig.TLSCert() - require.NoError(b, err) - - for n := 0; n < b.N; n++ { - endpointConfig.TLSCACertPool(cert) - } -} - -func BenchmarkTLSCertPoolDifferentCert(b *testing.B) { - customBackend := getCustomBackend() - customBackend.KeyValueMap["client.tlsCerts.systemCertPool"] = "true" - endpointConfig, err := ConfigFromBackend(customBackend) - require.NoError(b, err) - certConfig := endpoint.TLSConfig{Path: pathvar.Subst(certPath)} - cert, err := certConfig.TLSCert() - require.NoError(b, err) - - for n := 0; n < b.N; n++ { - cert.RawSubject = []byte(strconv.Itoa(n)) - endpointConfig.TLSCACertPool(cert) - } -} - func getMatcherConfig() core.ConfigBackend { cfgBackend, err := config.FromFile(configTestEntityMatchersFilePath)() if err != nil {