Skip to content

Commit

Permalink
[FAB-9601] Move cert pool wrapper into its own package
Browse files Browse the repository at this point in the history
Change-Id: I8e549dc957454bb15692d9285d3949c0f1b8c815
Signed-off-by: Divyank Katira <Divyank.Katira@securekey.com>
  • Loading branch information
d1vyank committed Apr 25, 2018
1 parent 830bdea commit ab35fb8
Show file tree
Hide file tree
Showing 4 changed files with 265 additions and 204 deletions.
117 changes: 117 additions & 0 deletions pkg/core/config/comm/tls/certpool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
Copyright SecureKey Technologies Inc. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
*/

package tls

import (
"crypto/x509"
"sync"

"github.com/hyperledger/fabric-sdk-go/pkg/common/logging"
)

var logger = logging.NewLogger("fabsdk/core")

// CertPool is a thread safe wrapper around the x509 standard library
// cert pool implementation.
type CertPool interface {
// Get returns the cert pool, optionally adding the provided certs
Get(certs ...*x509.Certificate) (*x509.CertPool, error)
}

// certPool is a thread safe wrapper around the x509 standard library
// cert pool implementation.
// It optionally allows loading the system trust store.
type certPool struct {
useSystemCertPool bool
certs []*x509.Certificate
certPool *x509.CertPool
certsByName map[string][]int
lock sync.RWMutex
}

// NewCertPool new CertPool implementation
func NewCertPool(useSystemCertPool bool) CertPool {
return &certPool{
useSystemCertPool: useSystemCertPool,
certsByName: make(map[string][]int),
certPool: x509.NewCertPool(),
}
}

func (c *certPool) Get(certs ...*x509.Certificate) (*x509.CertPool, error) {
c.lock.RLock()
if len(certs) == 0 || c.containsCerts(certs...) {
defer c.lock.RUnlock()
return c.certPool, nil
}
c.lock.RUnlock()

// We have a cert we have not encountered before, recreate the cert pool
certPool, err := c.loadSystemCertPool()
if err != nil {
return nil, err
}

c.lock.Lock()
defer c.lock.Unlock()

//add certs to SDK cert list
for _, newCert := range certs {
c.addCert(newCert)
}
//add all certs to cert pool
for _, cert := range c.certs {
certPool.AddCert(cert)
}
c.certPool = certPool

return c.certPool, nil
}

func (c *certPool) addCert(newCert *x509.Certificate) {
if newCert != nil && !c.containsCert(newCert) {
n := len(c.certs)
// Store cert
c.certs = append(c.certs, newCert)
// Store cert name index
name := string(newCert.RawSubject)
c.certsByName[name] = append(c.certsByName[name], n)
}
}

func (c *certPool) containsCert(newCert *x509.Certificate) bool {
possibilities := c.certsByName[string(newCert.RawSubject)]
for _, p := range possibilities {
if c.certs[p].Equal(newCert) {
return true
}
}

return false
}

func (c *certPool) containsCerts(certs ...*x509.Certificate) bool {
for _, cert := range certs {
if cert != nil && !c.containsCert(cert) {
return false
}
}
return true
}

func (c *certPool) loadSystemCertPool() (*x509.CertPool, error) {
if !c.useSystemCertPool {
return x509.NewCertPool(), nil
}
systemCertPool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
logger.Debugf("Loaded system cert pool of size: %d", len(systemCertPool.Subjects()))

return systemCertPool, nil
}
144 changes: 144 additions & 0 deletions pkg/core/config/comm/tls/certpool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
Copyright SecureKey Technologies Inc. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
*/

package tls

import (
"crypto/x509"
"strconv"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var goodCert = &x509.Certificate{
RawSubject: []byte("Good header"),
Raw: []byte("Good cert"),
}

func TestTLSCAConfig(t *testing.T) {
tlsCertPool := NewCertPool(true).(*certPool)
_, err := tlsCertPool.Get(goodCert)
require.NoError(t, err)
assert.Equal(t, true, tlsCertPool.useSystemCertPool)
assert.NotNil(t, tlsCertPool.certPool)
assert.NotNil(t, tlsCertPool.certsByName)

originalLength := len(tlsCertPool.certs)
//Try again with same cert
_, err = tlsCertPool.Get(goodCert)
assert.NoError(t, err, "TLS CA cert pool fetch failed")
assert.False(t, len(tlsCertPool.certs) > originalLength, "number of certs in cert list shouldn't accept duplicates")

// Test with system cert pool disabled
tlsCertPool = NewCertPool(false).(*certPool)
_, err = tlsCertPool.Get(goodCert)
require.NoError(t, err)
assert.Len(t, tlsCertPool.certs, 1)
assert.Len(t, tlsCertPool.certPool.Subjects(), 1)
}

func TestTLSCAPoolManyCerts(t *testing.T) {
size := 50

tlsCertPool := NewCertPool(true).(*certPool)
_, err := tlsCertPool.Get(goodCert)
require.NoError(t, err)

pool, err := tlsCertPool.Get()
assert.NoError(t, err)
originalLen := len(pool.Subjects())

certs := createNCerts(size)
pool, err = tlsCertPool.Get(certs[0])
assert.NoError(t, err)
assert.Len(t, pool.Subjects(), originalLen+1)

pool, err = tlsCertPool.Get(certs...)
assert.NoError(t, err)
assert.Len(t, pool.Subjects(), originalLen+size)
}

func TestConcurrent(t *testing.T) {
concurrency := 1000
certs := createNCerts(concurrency)

tlsCertPool := NewCertPool(false).(*certPool)

writeDone := make(chan bool)
readDone := make(chan bool)

for i := 0; i < concurrency; i++ {
go func(c *x509.Certificate) {
_, err := tlsCertPool.Get(c)
assert.NoError(t, err)
writeDone <- true
}(certs[i])
go func() {
_, err := tlsCertPool.Get()
assert.NoError(t, err)
readDone <- true
}()
}

for i := 0; i < concurrency; i++ {
select {
case b := <-writeDone:
assert.True(t, b)
case <-time.After(time.Second * 10):
t.Fatalf("Timed out waiting for write %d", i)
}

select {
case b := <-readDone:
assert.True(t, b)
case <-time.After(time.Second * 10):
t.Fatalf("Timed out waiting for read %d", i)
}
}

assert.Len(t, tlsCertPool.certs, concurrency)
assert.Len(t, tlsCertPool.certPool.Subjects(), concurrency)
}

func createNCerts(n int) []*x509.Certificate {
var certs []*x509.Certificate
for i := 0; i < n; i++ {
cert := &x509.Certificate{
RawSubject: []byte(strconv.Itoa(i)),
Raw: []byte(strconv.Itoa(i)),
}
certs = append(certs, cert)
}
return certs
}

func BenchmarkTLSCertPool(b *testing.B) {
tlsCertPool := NewCertPool(true).(*certPool)

for n := 0; n < b.N; n++ {
tlsCertPool.Get()
}
}

func BenchmarkTLSCertPoolSameCert(b *testing.B) {
tlsCertPool := NewCertPool(true).(*certPool)

for n := 0; n < b.N; n++ {
tlsCertPool.Get(goodCert)
}
}

func BenchmarkTLSCertPoolDifferentCert(b *testing.B) {
tlsCertPool := NewCertPool(true).(*certPool)
certs := createNCerts(b.N)

for n := 0; n < b.N; n++ {
tlsCertPool.Get(certs[n])
}
}
82 changes: 4 additions & 78 deletions pkg/fab/endpointconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"

"github.com/hyperledger/fabric-sdk-go/pkg/common/errors/status"
"github.com/hyperledger/fabric-sdk-go/pkg/common/logging"
"github.com/hyperledger/fabric-sdk-go/pkg/common/providers/core"
"github.com/hyperledger/fabric-sdk-go/pkg/common/providers/fab"
"github.com/hyperledger/fabric-sdk-go/pkg/common/providers/msp"
commtls "github.com/hyperledger/fabric-sdk-go/pkg/core/config/comm/tls"
"github.com/hyperledger/fabric-sdk-go/pkg/core/config/cryptoutil"
"github.com/hyperledger/fabric-sdk-go/pkg/core/config/endpoint"
"github.com/hyperledger/fabric-sdk-go/pkg/core/config/lookup"
Expand Down Expand Up @@ -60,8 +60,6 @@ const (
func ConfigFromBackend(coreBackend core.ConfigBackend) (fab.EndpointConfig, error) {
config := &EndpointConfig{
backend: lookup.New(coreBackend),
tlsCertsByName: make(map[string][]int),
tlsCertPool: x509.NewCertPool(),
peerMatchers: make(map[int]*regexp.Regexp),
ordererMatchers: make(map[int]*regexp.Regexp),
caMatchers: make(map[int]*regexp.Regexp),
Expand All @@ -72,6 +70,7 @@ func ConfigFromBackend(coreBackend core.ConfigBackend) (fab.EndpointConfig, erro
return nil, errors.WithMessage(err, "network configuration load failed")
}

config.tlsCertPool = commtls.NewCertPool(config.backend.GetBool("client.tlsCerts.systemCertPool"))
// preemptively add all TLS certs to cert pool as adding them at request time
// is expensive
certs, err := config.loadTLSCerts()
Expand All @@ -94,16 +93,13 @@ func ConfigFromBackend(coreBackend core.ConfigBackend) (fab.EndpointConfig, erro
// EndpointConfig represents the endpoint configuration for the client
type EndpointConfig struct {
backend *lookup.ConfigLookup
tlsCerts []*x509.Certificate
networkConfig *fab.NetworkConfig
tlsCertPool commtls.CertPool
networkConfigCached bool
tlsCertPool *x509.CertPool
peerMatchers map[int]*regexp.Regexp
ordererMatchers map[int]*regexp.Regexp
caMatchers map[int]*regexp.Regexp
channelMatchers map[int]*regexp.Regexp
tlsCertsByName map[string][]int
certPoolLock sync.RWMutex
}

// Timeout reads timeouts for the given timeout type, if type is not found in the config
Expand Down Expand Up @@ -464,33 +460,7 @@ func (c *EndpointConfig) ChannelOrderers(name string) ([]fab.OrdererConfig, erro
// TLSCACertPool returns the configured cert pool. If a certConfig
// is provided, the certficate is added to the pool
func (c *EndpointConfig) TLSCACertPool(certs ...*x509.Certificate) (*x509.CertPool, error) {
c.certPoolLock.RLock()
if len(certs) == 0 || c.containsCerts(certs...) {
defer c.certPoolLock.RUnlock()
return c.tlsCertPool, nil
}
c.certPoolLock.RUnlock()

// We have a cert we have not encountered before, recreate the cert pool
tlsCertPool, err := c.loadSystemCertPool()
if err != nil {
return nil, err
}

c.certPoolLock.Lock()
defer c.certPoolLock.Unlock()

//add certs to SDK cert list
for _, newCert := range certs {
c.addCert(newCert)
}
//add all certs to cert pool
for _, cert := range c.tlsCerts {
tlsCertPool.AddCert(cert)
}
c.tlsCertPool = tlsCertPool

return c.tlsCertPool, nil
return c.tlsCertPool.Get(certs...)
}

// EventServiceType returns the type of event service client to use
Expand Down Expand Up @@ -1107,50 +1077,6 @@ func (c *EndpointConfig) loadTLSCerts() ([]*x509.Certificate, error) {
return certs, nil
}

func (c *EndpointConfig) addCert(newCert *x509.Certificate) {
if newCert != nil && !c.containsCert(newCert) {
n := len(c.tlsCerts)
// Store cert
c.tlsCerts = append(c.tlsCerts, newCert)
// Store cert name index
name := string(newCert.RawSubject)
c.tlsCertsByName[name] = append(c.tlsCertsByName[name], n)
}
}

func (c *EndpointConfig) containsCert(newCert *x509.Certificate) bool {
possibilities := c.tlsCertsByName[string(newCert.RawSubject)]
for _, p := range possibilities {
if c.tlsCerts[p].Equal(newCert) {
return true
}
}

return false
}

func (c *EndpointConfig) containsCerts(certs ...*x509.Certificate) bool {
for _, cert := range certs {
if cert != nil && !c.containsCert(cert) {
return false
}
}
return true
}

func (c *EndpointConfig) loadSystemCertPool() (*x509.CertPool, error) {
if !c.backend.GetBool("client.tlsCerts.systemCertPool") {
return x509.NewCertPool(), nil
}
systemCertPool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
logger.Debugf("Loaded system cert pool of size: %d", len(systemCertPool.Subjects()))

return systemCertPool, nil
}

// Client returns the Client config
func (c *EndpointConfig) client() (*msp.ClientConfig, error) {
config, err := c.NetworkConfig()
Expand Down
Loading

0 comments on commit ab35fb8

Please sign in to comment.