Skip to content

Commit

Permalink
manager: rotate ca (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Oct 21, 2022
1 parent 80ff910 commit d3971a8
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 65 deletions.
4 changes: 2 additions & 2 deletions lib/util/security/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import (

const DefaultCertExpiration = 10 * 365 * 24 * time.Hour

func createTLSConfigificates(logger *zap.Logger, certpath, keypath, capath string, rsaKeySize int, expiration time.Duration) error {
func createTLSCertificates(logger *zap.Logger, certpath, keypath, capath string, rsaKeySize int, expiration time.Duration) error {
logger = logger.With(zap.String("cert", certpath), zap.String("key", keypath), zap.String("ca", capath), zap.Int("rsaKeySize", rsaKeySize))

_, e1 := os.Stat(certpath)
Expand Down Expand Up @@ -93,7 +93,7 @@ func AutoTLS(logger *zap.Logger, scfg *config.TLSConfig, autoca bool, workdir, m
if autoca {
scfg.CA = filepath.Join(workdir, mod, "ca.pem")
}
if err := createTLSConfigificates(logger, scfg.Cert, scfg.Key, scfg.CA, keySize, DefaultCertExpiration); err != nil {
if err := createTLSCertificates(logger, scfg.Cert, scfg.Key, scfg.CA, keySize, DefaultCertExpiration); err != nil {
return errors.WithStack(err)
}
return nil
Expand Down
150 changes: 150 additions & 0 deletions pkg/manager/cert/cert_info.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cert

import (
"crypto/tls"
"crypto/x509"
"sync/atomic"
"time"

"github.com/pingcap/TiProxy/lib/config"
"github.com/pingcap/TiProxy/lib/util/errors"
"github.com/pingcap/TiProxy/lib/util/security"
"go.uber.org/zap"
)

// Security configurations don't support dynamically updating now.
type certInfo struct {
cfg config.TLSConfig
tlsConfig atomic.Pointer[tls.Config]
autoCert bool
autoCertExp time.Time
isServer bool
}

func (ci *certInfo) buildTLSConfig(lg *zap.Logger) error {
builder := security.BuildClientTLSConfig
if ci.isServer {
builder = security.BuildServerTLSConfig
}
tlsConfig, err := builder(lg, ci.cfg)
if err == nil {
tlsConfig = ci.customizeTLSConfig(tlsConfig)
ci.tlsConfig.Store(tlsConfig)
}
return err
}

func (ci *certInfo) getCertificate() *tls.Certificate {
tlsConfig := ci.tlsConfig.Load()
if tlsConfig != nil && len(tlsConfig.Certificates) > 0 {
return &tlsConfig.Certificates[0]
}
return nil
}

func (ci *certInfo) getCAs() *x509.CertPool {
tlsConfig := ci.tlsConfig.Load()
if tlsConfig != nil {
if ci.isServer {
return tlsConfig.ClientCAs
}
return tlsConfig.RootCAs
}
return nil
}

func (ci *certInfo) verifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate) error {
if len(rawCerts) == 0 {
return nil
}

certs := make([]*x509.Certificate, len(rawCerts))
for i, asn1Data := range rawCerts {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
return errors.New("tls: failed to parse certificate from server: " + err.Error())
}
certs[i] = cert
}

latestConfig := ci.tlsConfig.Load()
t := latestConfig.Time
if t == nil {
t = time.Now
}
opts := x509.VerifyOptions{
Roots: ci.getCAs(),
CurrentTime: t(),
Intermediates: x509.NewCertPool(),
}
if ci.isServer {
opts.KeyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
} else {
opts.DNSName = latestConfig.ServerName
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
_, err := certs[0].Verify(opts)
return err
}

// Some methods to rotate server config:
// - For certs: customize GetCertificate / GetConfigForClient.
// - For CA: customize ClientAuth + VerifyPeerCertificate / GetConfigForClient
// Some methods to rotate client config:
// - For certs: customize GetClientCertificate
// - For CA: customize InsecureSkipVerify + VerifyPeerCertificate
func (ci *certInfo) customizeTLSConfig(tlsConfig *tls.Config) *tls.Config {
if tlsConfig == nil {
return nil
}
tlsConfig = tlsConfig.Clone()
if ci.isServer {
tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return ci.getCertificate(), nil
}
if tlsConfig.ClientAuth >= tls.VerifyClientCertIfGiven {
tlsConfig.ClientAuth = tls.RequireAnyClientCert
tlsConfig.VerifyPeerCertificate = ci.verifyPeerCertificate
}
} else {
tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return ci.getCertificate(), nil
}
if !tlsConfig.InsecureSkipVerify {
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = ci.verifyPeerCertificate
}
}
return tlsConfig
}

func (ci *certInfo) getTLS() *tls.Config {
return ci.tlsConfig.Load()
}

func (ci *certInfo) setAutoCertExp(exp time.Time) {
ci.autoCertExp = exp
}

func (ci *certInfo) needRecreateCert(now time.Time) bool {
if !ci.autoCert {
return false
}
return now.After(ci.autoCertExp)
}
58 changes: 4 additions & 54 deletions pkg/manager/cert/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,51 +32,6 @@ const (
defaultAutoCertInterval = 30 * 24 * time.Hour
)

// Security configurations don't support dynamically updating now.
type certInfo struct {
cfg config.TLSConfig
tlsConfig *tls.Config
certificate atomic.Pointer[tls.Certificate]
autoCert bool
autoCertExp time.Time
}

func (ci *certInfo) getTLS() *tls.Config {
if ci.tlsConfig != nil {
return ci.tlsConfig.Clone()
}
return nil
}

func (ci *certInfo) setTLS(tlsConfig *tls.Config) {
if tlsConfig != nil {
tlsConfig = tlsConfig.Clone()
if tlsConfig.Certificates != nil {
ci.certificate.Store(&tlsConfig.Certificates[0])
// Doesn't support rotating CA now. It needs overwriting InsecureSkipVerify and VerifyPeerCertificate.
tlsConfig.GetCertificate = func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return ci.certificate.Load(), nil
}
tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return ci.certificate.Load(), nil
}
tlsConfig.Certificates = nil
}
}
ci.tlsConfig = tlsConfig
}

func (ci *certInfo) setAutoCertExp(exp time.Time) {
ci.autoCertExp = exp
}

func (ci *certInfo) needRecreateCert(now time.Time) bool {
if !ci.autoCert {
return false
}
return now.After(ci.autoCertExp)
}

// CertManager reloads certs and offers interfaces for fetching TLS configs.
// Currently, all the namespaces share the same certs but there might be per-namespace
// certs in the future.
Expand Down Expand Up @@ -109,6 +64,7 @@ func (cm *CertManager) Init(cfg *config.Config, logger *zap.Logger) error {
cm.serverTLS = certInfo{
cfg: cfg.Security.ServerTLS,
autoCert: !cfg.Security.ServerTLS.HasCert() && cfg.Security.ServerTLS.AutoCerts,
isServer: true,
}
cm.peerTLS = certInfo{
cfg: cfg.Security.PeerTLS,
Expand Down Expand Up @@ -185,12 +141,10 @@ func (cm *CertManager) load() error {
needReloadServer = true
}
if needReloadServer {
var tlsConfig *tls.Config
if tlsConfig, err = security.BuildServerTLSConfig(cm.logger, cm.serverTLS.cfg); err != nil {
if err = cm.serverTLS.buildTLSConfig(cm.logger); err != nil {
cm.logger.Error("loading server certs failed", zap.Error(err))
errs = append(errs, err)
} else {
cm.serverTLS.setTLS(tlsConfig)
cm.serverTLS.setAutoCertExp(now.Add(time.Duration(cm.autoCertInterval.Load())))
}
}
Expand All @@ -205,18 +159,14 @@ func (cm *CertManager) load() error {
}
}

if tlsConfig, err := security.BuildClientTLSConfig(cm.logger, cm.sqlTLS.cfg); err != nil {
if err = cm.sqlTLS.buildTLSConfig(cm.logger); err != nil {
cm.logger.Error("loading sql certs failed", zap.Error(err))
errs = append(errs, err)
} else {
cm.sqlTLS.setTLS(tlsConfig)
}

if tlsConfig, err := security.BuildClientTLSConfig(cm.logger, cm.clusterTLS.cfg); err != nil {
if err = cm.clusterTLS.buildTLSConfig(cm.logger); err != nil {
cm.logger.Error("loading cluster certs failed", zap.Error(err))
errs = append(errs, err)
} else {
cm.clusterTLS.setTLS(tlsConfig)
}

if len(errs) != 0 {
Expand Down
Loading

0 comments on commit d3971a8

Please sign in to comment.