-
Notifications
You must be signed in to change notification settings - Fork 506
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FAB-9601] Move cert pool wrapper into its own package
Change-Id: I8e549dc957454bb15692d9285d3949c0f1b8c815 Signed-off-by: Divyank Katira <Divyank.Katira@securekey.com>
- Loading branch information
Showing
4 changed files
with
265 additions
and
204 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
/* | ||
Copyright SecureKey Technologies Inc. All Rights Reserved. | ||
SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package tls | ||
|
||
import ( | ||
"crypto/x509" | ||
"sync" | ||
|
||
"github.com/hyperledger/fabric-sdk-go/pkg/common/logging" | ||
) | ||
|
||
var logger = logging.NewLogger("fabsdk/core") | ||
|
||
// CertPool is a thread safe wrapper around the x509 standard library | ||
// cert pool implementation. | ||
type CertPool interface { | ||
// Get returns the cert pool, optionally adding the provided certs | ||
Get(certs ...*x509.Certificate) (*x509.CertPool, error) | ||
} | ||
|
||
// certPool is a thread safe wrapper around the x509 standard library | ||
// cert pool implementation. | ||
// It optionally allows loading the system trust store. | ||
type certPool struct { | ||
useSystemCertPool bool | ||
certs []*x509.Certificate | ||
certPool *x509.CertPool | ||
certsByName map[string][]int | ||
lock sync.RWMutex | ||
} | ||
|
||
// NewCertPool new CertPool implementation | ||
func NewCertPool(useSystemCertPool bool) CertPool { | ||
return &certPool{ | ||
useSystemCertPool: useSystemCertPool, | ||
certsByName: make(map[string][]int), | ||
certPool: x509.NewCertPool(), | ||
} | ||
} | ||
|
||
func (c *certPool) Get(certs ...*x509.Certificate) (*x509.CertPool, error) { | ||
c.lock.RLock() | ||
if len(certs) == 0 || c.containsCerts(certs...) { | ||
defer c.lock.RUnlock() | ||
return c.certPool, nil | ||
} | ||
c.lock.RUnlock() | ||
|
||
// We have a cert we have not encountered before, recreate the cert pool | ||
certPool, err := c.loadSystemCertPool() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
c.lock.Lock() | ||
defer c.lock.Unlock() | ||
|
||
//add certs to SDK cert list | ||
for _, newCert := range certs { | ||
c.addCert(newCert) | ||
} | ||
//add all certs to cert pool | ||
for _, cert := range c.certs { | ||
certPool.AddCert(cert) | ||
} | ||
c.certPool = certPool | ||
|
||
return c.certPool, nil | ||
} | ||
|
||
func (c *certPool) addCert(newCert *x509.Certificate) { | ||
if newCert != nil && !c.containsCert(newCert) { | ||
n := len(c.certs) | ||
// Store cert | ||
c.certs = append(c.certs, newCert) | ||
// Store cert name index | ||
name := string(newCert.RawSubject) | ||
c.certsByName[name] = append(c.certsByName[name], n) | ||
} | ||
} | ||
|
||
func (c *certPool) containsCert(newCert *x509.Certificate) bool { | ||
possibilities := c.certsByName[string(newCert.RawSubject)] | ||
for _, p := range possibilities { | ||
if c.certs[p].Equal(newCert) { | ||
return true | ||
} | ||
} | ||
|
||
return false | ||
} | ||
|
||
func (c *certPool) containsCerts(certs ...*x509.Certificate) bool { | ||
for _, cert := range certs { | ||
if cert != nil && !c.containsCert(cert) { | ||
return false | ||
} | ||
} | ||
return true | ||
} | ||
|
||
func (c *certPool) loadSystemCertPool() (*x509.CertPool, error) { | ||
if !c.useSystemCertPool { | ||
return x509.NewCertPool(), nil | ||
} | ||
systemCertPool, err := x509.SystemCertPool() | ||
if err != nil { | ||
return nil, err | ||
} | ||
logger.Debugf("Loaded system cert pool of size: %d", len(systemCertPool.Subjects())) | ||
|
||
return systemCertPool, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
/* | ||
Copyright SecureKey Technologies Inc. All Rights Reserved. | ||
SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package tls | ||
|
||
import ( | ||
"crypto/x509" | ||
"strconv" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
var goodCert = &x509.Certificate{ | ||
RawSubject: []byte("Good header"), | ||
Raw: []byte("Good cert"), | ||
} | ||
|
||
func TestTLSCAConfig(t *testing.T) { | ||
tlsCertPool := NewCertPool(true).(*certPool) | ||
_, err := tlsCertPool.Get(goodCert) | ||
require.NoError(t, err) | ||
assert.Equal(t, true, tlsCertPool.useSystemCertPool) | ||
assert.NotNil(t, tlsCertPool.certPool) | ||
assert.NotNil(t, tlsCertPool.certsByName) | ||
|
||
originalLength := len(tlsCertPool.certs) | ||
//Try again with same cert | ||
_, err = tlsCertPool.Get(goodCert) | ||
assert.NoError(t, err, "TLS CA cert pool fetch failed") | ||
assert.False(t, len(tlsCertPool.certs) > originalLength, "number of certs in cert list shouldn't accept duplicates") | ||
|
||
// Test with system cert pool disabled | ||
tlsCertPool = NewCertPool(false).(*certPool) | ||
_, err = tlsCertPool.Get(goodCert) | ||
require.NoError(t, err) | ||
assert.Len(t, tlsCertPool.certs, 1) | ||
assert.Len(t, tlsCertPool.certPool.Subjects(), 1) | ||
} | ||
|
||
func TestTLSCAPoolManyCerts(t *testing.T) { | ||
size := 50 | ||
|
||
tlsCertPool := NewCertPool(true).(*certPool) | ||
_, err := tlsCertPool.Get(goodCert) | ||
require.NoError(t, err) | ||
|
||
pool, err := tlsCertPool.Get() | ||
assert.NoError(t, err) | ||
originalLen := len(pool.Subjects()) | ||
|
||
certs := createNCerts(size) | ||
pool, err = tlsCertPool.Get(certs[0]) | ||
assert.NoError(t, err) | ||
assert.Len(t, pool.Subjects(), originalLen+1) | ||
|
||
pool, err = tlsCertPool.Get(certs...) | ||
assert.NoError(t, err) | ||
assert.Len(t, pool.Subjects(), originalLen+size) | ||
} | ||
|
||
func TestConcurrent(t *testing.T) { | ||
concurrency := 1000 | ||
certs := createNCerts(concurrency) | ||
|
||
tlsCertPool := NewCertPool(false).(*certPool) | ||
|
||
writeDone := make(chan bool) | ||
readDone := make(chan bool) | ||
|
||
for i := 0; i < concurrency; i++ { | ||
go func(c *x509.Certificate) { | ||
_, err := tlsCertPool.Get(c) | ||
assert.NoError(t, err) | ||
writeDone <- true | ||
}(certs[i]) | ||
go func() { | ||
_, err := tlsCertPool.Get() | ||
assert.NoError(t, err) | ||
readDone <- true | ||
}() | ||
} | ||
|
||
for i := 0; i < concurrency; i++ { | ||
select { | ||
case b := <-writeDone: | ||
assert.True(t, b) | ||
case <-time.After(time.Second * 10): | ||
t.Fatalf("Timed out waiting for write %d", i) | ||
} | ||
|
||
select { | ||
case b := <-readDone: | ||
assert.True(t, b) | ||
case <-time.After(time.Second * 10): | ||
t.Fatalf("Timed out waiting for read %d", i) | ||
} | ||
} | ||
|
||
assert.Len(t, tlsCertPool.certs, concurrency) | ||
assert.Len(t, tlsCertPool.certPool.Subjects(), concurrency) | ||
} | ||
|
||
func createNCerts(n int) []*x509.Certificate { | ||
var certs []*x509.Certificate | ||
for i := 0; i < n; i++ { | ||
cert := &x509.Certificate{ | ||
RawSubject: []byte(strconv.Itoa(i)), | ||
Raw: []byte(strconv.Itoa(i)), | ||
} | ||
certs = append(certs, cert) | ||
} | ||
return certs | ||
} | ||
|
||
func BenchmarkTLSCertPool(b *testing.B) { | ||
tlsCertPool := NewCertPool(true).(*certPool) | ||
|
||
for n := 0; n < b.N; n++ { | ||
tlsCertPool.Get() | ||
} | ||
} | ||
|
||
func BenchmarkTLSCertPoolSameCert(b *testing.B) { | ||
tlsCertPool := NewCertPool(true).(*certPool) | ||
|
||
for n := 0; n < b.N; n++ { | ||
tlsCertPool.Get(goodCert) | ||
} | ||
} | ||
|
||
func BenchmarkTLSCertPoolDifferentCert(b *testing.B) { | ||
tlsCertPool := NewCertPool(true).(*certPool) | ||
certs := createNCerts(b.N) | ||
|
||
for n := 0; n < b.N; n++ { | ||
tlsCertPool.Get(certs[n]) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.