Skip to content

Commit

Permalink
manager: remove expiration judgement and reload certs periodically (#244
Browse files Browse the repository at this point in the history
)
  • Loading branch information
djshow832 authored Mar 9, 2023
1 parent 5a00605 commit f05556d
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 70 deletions.
65 changes: 26 additions & 39 deletions lib/util/security/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,20 @@ import (
"go.uber.org/zap"
)

const (
// Recreate the auto certs one hour before it expires.
// It should be longer than defaultRetryInterval.
recreateAutoCertAdvance = 24 * time.Hour
)

var emptyCert = new(tls.Certificate)

type CertInfo struct {
cfg config.TLSConfig
ca atomic.Value
cert atomic.Value
expire atomic.Int64
server bool
cfg config.TLSConfig
ca atomic.Value
cert atomic.Value
autoCertExp atomic.Int64
server bool
}

func NewCert(lg *zap.Logger, cfg config.TLSConfig, server bool) (*CertInfo, *tls.Config, error) {
Expand All @@ -46,10 +52,7 @@ func NewCert(lg *zap.Logger, cfg config.TLSConfig, server bool) (*CertInfo, *tls
return ci, tlscfg, err
}

func (ci *CertInfo) Reload(lg *zap.Logger, n time.Time) error {
if n.Unix() <= ci.expire.Load() {
return nil
}
func (ci *CertInfo) Reload(lg *zap.Logger) error {
_, err := ci.reload(lg)
return err
}
Expand Down Expand Up @@ -129,11 +132,6 @@ func (ci *CertInfo) verifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certifi
return err
}

func (ci *CertInfo) updateMinExpire(n int64) {
for o := ci.expire.Load(); o > n && !ci.expire.CAS(o, n); o = ci.expire.Load() {
}
}

func (ci *CertInfo) loadCA(pemCerts []byte) (*x509.CertPool, error) {
pool := x509.NewCertPool()
for len(pemCerts) > 0 {
Expand All @@ -151,9 +149,6 @@ func (ci *CertInfo) loadCA(pemCerts []byte) (*x509.CertPool, error) {
if err != nil {
continue
}

ci.updateMinExpire(cert.NotAfter.Unix())

pool.AddCert(cert)
}
return pool, nil
Expand Down Expand Up @@ -181,13 +176,17 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) {
var certPEM, keyPEM []byte
var err error
if autoCerts {
dur, err := time.ParseDuration(ci.cfg.AutoExpireDuration)
if err != nil {
dur = DefaultCertExpiration
}
certPEM, keyPEM, _, err = createTempTLS(ci.cfg.RSAKeySize, dur)
if err != nil {
return nil, err
now := time.Now()
if time.Unix(ci.autoCertExp.Load(), 0).Before(now) {
dur, err := time.ParseDuration(ci.cfg.AutoExpireDuration)
if err != nil {
dur = DefaultCertExpiration
}
ci.autoCertExp.Store(now.Add(DefaultCertExpiration - recreateAutoCertAdvance).Unix())
certPEM, keyPEM, _, err = createTempTLS(ci.cfg.RSAKeySize, dur)
if err != nil {
return nil, err
}
}
} else {
certPEM, err = os.ReadFile(ci.cfg.Cert)
Expand All @@ -200,18 +199,13 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) {
}
}

cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, errors.WithStack(err)
}
for _, c := range cert.Certificate {
cp, err := x509.ParseCertificate(c)
if certPEM != nil {
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, errors.WithStack(err)
}
ci.updateMinExpire(cp.NotAfter.Unix())
ci.cert.Store(&cert)
}
ci.cert.Store(&cert)

if !ci.cfg.HasCA() {
lg.Info("no CA, server will not authenticate clients (connection is still secured)")
Expand Down Expand Up @@ -283,13 +277,6 @@ func (ci *CertInfo) buildClientConfig(lg *zap.Logger) (*tls.Config, error) {
if err != nil {
return nil, errors.WithStack(err)
}
for _, c := range cert.Certificate {
cp, err := x509.ParseCertificate(c)
if err != nil {
return nil, errors.WithStack(err)
}
ci.updateMinExpire(cp.NotAfter.Unix())
}
ci.cert.Store(&cert)

return tcfg, nil
Expand Down
64 changes: 64 additions & 0 deletions lib/util/security/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package security

import (
"crypto/tls"
"crypto/x509"
"path/filepath"
"testing"
"time"
Expand Down Expand Up @@ -216,3 +217,66 @@ func TestCertServer(t *testing.T) {
}
}
}

func TestReload(t *testing.T) {
lg := logger.CreateLoggerForTest(t)
tmpdir := t.TempDir()
certPath := filepath.Join(tmpdir, "cert")
keyPath := filepath.Join(tmpdir, "key")
caPath := filepath.Join(tmpdir, "ca")
cfg := config.TLSConfig{
CA: caPath,
Cert: certPath,
Key: keyPath,
}

// Create a cert and record the expiration.
require.NoError(t, CreateTLSCertificates(lg, certPath, keyPath, caPath, 0, time.Hour))
ci, tcfg, err := NewCert(lg, cfg, true)
require.NoError(t, err)
require.NotNil(t, tcfg)
expire1 := getExpireTime(t, ci)

// Replace the cert and then reload. Check that the expiration is different.
err = CreateTLSCertificates(lg, certPath, keyPath, caPath, 0, 2*time.Hour)
require.NoError(t, err)
require.NoError(t, ci.Reload(lg))
expire2 := getExpireTime(t, ci)
require.NotEqual(t, expire1, expire2)
}

func TestAutoCerts(t *testing.T) {
lg := logger.CreateLoggerForTest(t)
cfg := config.TLSConfig{
AutoCerts: true,
}

// Create an auto cert.
ci, tcfg, err := NewCert(lg, cfg, true)
require.NoError(t, err)
require.NotNil(t, tcfg)
cert1 := ci.cert.Load().(*tls.Certificate)
expire1 := getExpireTime(t, ci)
require.True(t, ci.autoCertExp.Load() < expire1.Unix())

// The cert will not be recreated now.
ci.cfg.AutoExpireDuration = (DefaultCertExpiration - time.Hour).String()
require.NoError(t, ci.Reload(lg))
cert2 := ci.cert.Load().(*tls.Certificate)
require.Equal(t, cert1, cert2)
expire2 := getExpireTime(t, ci)
require.Equal(t, expire1, expire2)

// The cert will be recreated when it almost expires.
ci.autoCertExp.Store(time.Now().Add(-time.Minute).Unix())
require.NoError(t, ci.Reload(lg))
expire3 := getExpireTime(t, ci)
require.NotEqual(t, expire1, expire3)
}

func getExpireTime(t *testing.T, ci *CertInfo) time.Time {
cert := ci.cert.Load().(*tls.Certificate)
cp, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
return cp.NotAfter
}
13 changes: 6 additions & 7 deletions pkg/manager/cert/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ func (cm *CertManager) SQLTLS() *tls.Config {

// The proxy is supposed to be always online, so it should reload certs automatically,
// rather than reloading it by restarting the proxy.
// The proxy checks expiration time periodically and reloads certs in advance. If reloading
// fails or the cert is not replaced, it will retry in the next round.
// The proxy periodically reloads certs. If it fails, we will retry in the next round.
func (cm *CertManager) reloadLoop(ctx context.Context) {
cm.wg.Run(func() {
for {
Expand All @@ -120,19 +119,19 @@ func (cm *CertManager) reloadLoop(ctx context.Context) {
})
}

// If any error happens, we still continue and use the old cert.
func (cm *CertManager) reload() {
now := time.Now()
errs := make([]error, 0, 4)
if err := cm.serverTLS.Reload(cm.logger, now); err != nil {
if err := cm.serverTLS.Reload(cm.logger); err != nil {
errs = append(errs, err)
}
if err := cm.peerTLS.Reload(cm.logger, now); err != nil {
if err := cm.peerTLS.Reload(cm.logger); err != nil {
errs = append(errs, err)
}
if err := cm.clusterTLS.Reload(cm.logger, now); err != nil {
if err := cm.clusterTLS.Reload(cm.logger); err != nil {
errs = append(errs, err)
}
if err := cm.sqlTLS.Reload(cm.logger, now); err != nil {
if err := cm.sqlTLS.Reload(cm.logger); err != nil {
errs = append(errs, err)
}
err := errors.Collect(errors.New("loading certs"), errs...)
Expand Down
79 changes: 55 additions & 24 deletions pkg/manager/cert/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"net"
"os"
"path/filepath"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -309,34 +310,64 @@ func TestRotate(t *testing.T) {
tc.reload(t)
}

timer := time.NewTimer(time.Second)
outer:
for {
select {
case <-timer.C:
t.Fatal("timeout on reloading")
case <-time.After(150 * time.Millisecond):
clientErr, serverErr := connectWithTLS(ctls, stls)
errmsg := fmt.Sprintf("client: %+v\nserver: %+v\n", clientErr, serverErr)
if tc.relErrCli != "" {
require.ErrorContains(t, clientErr, tc.relErrCli, errmsg)
require.Error(t, serverErr)
break outer
}
if tc.relErrSrv != "" {
require.ErrorContains(t, serverErr, tc.relErrSrv, errmsg)
require.Error(t, clientErr)
break outer
time.Sleep(150 * time.Millisecond)
require.Eventually(t, func() bool {
clientErr, serverErr := connectWithTLS(ctls, stls)
if tc.relErrCli != "" {
if !strings.Contains(clientErr.Error(), tc.relErrCli) || serverErr == nil {
t.Logf("clientErr: %+v, serverErr: %+v\n", clientErr, serverErr)
return false
}
if tc.relErrCli == "" && tc.relErrSrv == "" {
if clientErr == nil && serverErr == nil {
break outer
}
}
if tc.relErrSrv != "" {
if !strings.Contains(serverErr.Error(), tc.relErrSrv) || clientErr == nil {
t.Logf("clientErr: %+v, serverErr: %+v\n", clientErr, serverErr)
return false
}
}
}

if tc.relErrCli == "" && tc.relErrSrv == "" {
return clientErr == nil && serverErr == nil
}
return true
}, time.Second, 100*time.Millisecond)
certMgr.Close()
}
}

func TestBidirectional(t *testing.T) {
tmpdir := t.TempDir()
lg := logger.CreateLoggerForTest(t)
caPath1 := filepath.Join(tmpdir, "c1", "ca")
keyPath1 := filepath.Join(tmpdir, "c1", "key")
certPath1 := filepath.Join(tmpdir, "c1", "cert")
caPath2 := filepath.Join(tmpdir, "c2", "ca")
keyPath2 := filepath.Join(tmpdir, "c2", "key")
certPath2 := filepath.Join(tmpdir, "c2", "cert")

require.NoError(t, security.CreateTLSCertificates(lg, certPath1, keyPath1, caPath1, 0, security.DefaultCertExpiration))
require.NoError(t, security.CreateTLSCertificates(lg, certPath2, keyPath2, caPath2, 0, security.DefaultCertExpiration))

cfg := &config.Config{
Workdir: tmpdir,
Security: config.Security{
ServerTLS: config.TLSConfig{
Cert: certPath1,
Key: keyPath1,
CA: caPath2,
},
SQLTLS: config.TLSConfig{
CA: caPath1,
Key: keyPath2,
Cert: certPath2,
},
},
}

certMgr := NewCertManager()
require.NoError(t, certMgr.Init(cfg, lg))
stls := certMgr.ServerTLS()
ctls := certMgr.SQLTLS()
clientErr, serverErr := connectWithTLS(ctls, stls)
require.NoError(t, clientErr)
require.NoError(t, serverErr)
}

0 comments on commit f05556d

Please sign in to comment.