Skip to content

Commit

Permalink
*: tls refactor (pingcap#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhebox committed Mar 7, 2023
1 parent 8037cf1 commit 34af7ee
Show file tree
Hide file tree
Showing 7 changed files with 366 additions and 264 deletions.
4 changes: 4 additions & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ issues:
linters:
- gosec
text: "G402:"
- path: util/security/cert.go
linters:
- gosec
text: "G402:"

linters:
enable:
Expand Down
12 changes: 7 additions & 5 deletions lib/config/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,13 @@ type LogFile struct {
}

type TLSConfig struct {
Cert string `yaml:"cert,omitempty" toml:"cert,omitempty" json:"cert,omitempty"`
Key string `yaml:"key,omitempty" toml:"key,omitempty" json:"key,omitempty"`
CA string `yaml:"ca,omitempty" toml:"ca,omitempty" json:"ca,omitempty"`
AutoCerts bool `yaml:"auto-certs,omitempty" toml:"auto-certs,omitempty" json:"auto-certs,omitempty"`
SkipCA bool `yaml:"skip-ca,omitempty" toml:"skip-ca,omitempty" json:"skip-ca,omitempty"`
Cert string `yaml:"cert,omitempty" toml:"cert,omitempty" json:"cert,omitempty"`
Key string `yaml:"key,omitempty" toml:"key,omitempty" json:"key,omitempty"`
CA string `yaml:"ca,omitempty" toml:"ca,omitempty" json:"ca,omitempty"`
AutoCerts bool `yaml:"auto-certs,omitempty" toml:"auto-certs,omitempty" json:"auto-certs,omitempty"`
RSAKeySize int `yaml:"rsa-key-size,omitempty" toml:"rsa-key-size,omitempty" json:"rsa-key-size,omitempty"`
AutoExpireDuration string `yaml:"autocert-expire-duration,omitempty" toml:"autocert-expire-duration,omitempty" json:"autocert-expire-duration,omitempty"`
SkipCA bool `yaml:"skip-ca,omitempty" toml:"skip-ca,omitempty" json:"skip-ca,omitempty"`
}

func (c TLSConfig) HasCert() bool {
Expand Down
283 changes: 283 additions & 0 deletions lib/util/security/cert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package security

import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"os"
"sync/atomic"
"time"

"github.com/pingcap/TiProxy/lib/config"
"github.com/pingcap/TiProxy/lib/util/errors"
"go.uber.org/zap"
)

var emptyCert = new(tls.Certificate)

type CertInfo struct {
cfg config.TLSConfig
ca atomic.Value
cert atomic.Value
expire atomic.Int64
server bool
}

func NewCert(lg *zap.Logger, cfg config.TLSConfig, server bool) (*CertInfo, *tls.Config, error) {
ci := &CertInfo{
cfg: cfg,
server: server,
}
tlscfg, err := ci.reload(lg)
return ci, tlscfg, err
}

func (ci *CertInfo) Reload(lg *zap.Logger, n time.Time) error {
if n.Unix() <= ci.expire.Load() {
return nil
}
_, err := ci.reload(lg)
return err
}

func (ci *CertInfo) reload(lg *zap.Logger) (*tls.Config, error) {
var tlsConfig *tls.Config
var err error
// Some methods to rotate server config:
// - For certs: customize GetCertificate / GetConfigForClient.
// - For CA: customize ClientAuth + VerifyPeerCertificate / GetConfigForClient
// Some methods to rotate client config:
// - For certs: customize GetClientCertificate
// - For CA: customize InsecureSkipVerify + VerifyPeerCertificate
if ci.server {
tlsConfig, err = ci.buildServerConfig(lg)
} else {
tlsConfig, err = ci.buildClientConfig(lg)
}
return tlsConfig, err
}

func (ci *CertInfo) getCert(*tls.ClientHelloInfo) (*tls.Certificate, error) {
cert := ci.cert.Load()
if val, ok := cert.(*tls.Certificate); ok {
return val, nil
}
return nil, nil
}

func (ci *CertInfo) getClientCert(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
cert := ci.cert.Load()
if val, ok := cert.(*tls.Certificate); ok {
return val, nil
}
if cert == nil {
// GetClientCertificate must return a non-nil Certificate.
return emptyCert, nil
}
return nil, nil
}

func (ci *CertInfo) verifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate) error {
if len(rawCerts) == 0 {
return nil
}

certs := make([]*x509.Certificate, len(rawCerts))
for i, asn1Data := range rawCerts {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
return errors.New("tls: failed to parse certificate from server: " + err.Error())
}
certs[i] = cert
}

cas, ok := ci.ca.Load().(*x509.CertPool)
if !ok {
cas = x509.NewCertPool()
}
opts := x509.VerifyOptions{
Roots: cas,
Intermediates: x509.NewCertPool(),
}
if ci.server {
opts.KeyUsages = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
}
// TODO: not implemented, maybe later
// opts.DNSName = ci.serverName
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
_, err := certs[0].Verify(opts)
return err
}

func (ci *CertInfo) updateMinExpire(n int64) {
for o := ci.expire.Load(); o > n && !ci.expire.CompareAndSwap(o, n); o = ci.expire.Load() {
}
}

func (ci *CertInfo) loadCA(pemCerts []byte) (*x509.CertPool, error) {
pool := x509.NewCertPool()
for len(pemCerts) > 0 {
var block *pem.Block
block, pemCerts = pem.Decode(pemCerts)
if block == nil {
break
}
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
continue
}

certBytes := block.Bytes
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
continue
}

ci.updateMinExpire(cert.NotAfter.Unix())

pool.AddCert(cert)
}
return pool, nil
}

func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) {
lg = lg.With(zap.String("tls", "server"), zap.Any("cfg", ci.cfg))
autoCerts := false
if !ci.cfg.HasCert() {
if ci.cfg.AutoCerts {
autoCerts = true
} else {
lg.Warn("require certificates to secure clients connections, disable TLS")
return nil, nil
}
}

tcfg := &tls.Config{
MinVersion: tls.VersionTLS12,
GetCertificate: ci.getCert,
GetClientCertificate: ci.getClientCert,
VerifyPeerCertificate: ci.verifyPeerCertificate,
}

var certPEM, keyPEM, caPEM []byte
var err error
if autoCerts {
dur, err := time.ParseDuration(ci.cfg.AutoExpireDuration)
if err != nil {
dur = DefaultCertExpiration
}
certPEM, keyPEM, caPEM, err = CreateTempTLS(ci.cfg.RSAKeySize, dur)
if err != nil {
return nil, err
}
} else {
certPEM, err = os.ReadFile(ci.cfg.Cert)
if err != nil {
return nil, err
}
keyPEM, err = os.ReadFile(ci.cfg.Key)
if err != nil {
return nil, err
}
if !ci.cfg.HasCA() {
lg.Warn("no CA, server will not authenticate clients (connection is still secured)")
return tcfg, nil
} else {
caPEM, err = os.ReadFile(ci.cfg.CA)
if err != nil {
return nil, err
}
}
}

cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, errors.WithStack(err)
}
for _, c := range cert.Certificate {
cp, err := x509.ParseCertificate(c)
if err != nil {
return nil, errors.WithStack(err)
}
ci.updateMinExpire(cp.NotAfter.Unix())
}
ci.cert.Store(&cert)

if len(caPEM) != 0 {
cas, err := ci.loadCA(caPEM)
if err != nil {
return nil, errors.WithStack(err)
}
ci.ca.Store(cas)
tcfg.ClientAuth = tls.RequireAnyClientCert
}
return tcfg, nil
}

func (ci *CertInfo) buildClientConfig(lg *zap.Logger) (*tls.Config, error) {
lg = lg.With(zap.String("tls", "client"), zap.Any("cfg", ci.cfg))
if !ci.cfg.HasCA() {
if ci.cfg.SkipCA {
// still enable TLS without verify server certs
return &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12,
}, nil
}
lg.Warn("no CA to verify server connections, disable TLS")
return nil, nil
}

tcfg := &tls.Config{
MinVersion: tls.VersionTLS12,
GetCertificate: ci.getCert,
GetClientCertificate: ci.getClientCert,
InsecureSkipVerify: true,
VerifyPeerCertificate: ci.verifyPeerCertificate,
}

certBytes, err := os.ReadFile(ci.cfg.CA)
if err != nil {
return nil, errors.WithStack(err)
}
cas, err := ci.loadCA(certBytes)
if err != nil {
return nil, errors.WithStack(err)
}
ci.ca.Store(cas)

if !ci.cfg.HasCert() {
lg.Warn("no certificates, server may reject the connection")
return tcfg, nil
}

cert, err := tls.LoadX509KeyPair(ci.cfg.Cert, ci.cfg.Key)
if err != nil {
return nil, errors.WithStack(err)
}
for _, c := range cert.Certificate {
cp, err := x509.ParseCertificate(c)
if err != nil {
return nil, errors.WithStack(err)
}
ci.updateMinExpire(cp.NotAfter.Unix())
}
ci.cert.Store(&cert)

return tcfg, nil
}
14 changes: 7 additions & 7 deletions lib/util/security/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ func createTLSCertificates(logger *zap.Logger, certpath, keypath, capath string,
return err
}

if err := os.WriteFile(certpath, certPEM.Bytes(), 0600); err != nil {
if err := os.WriteFile(certpath, certPEM, 0600); err != nil {
return err
}
if err := os.WriteFile(keypath, keyPEM.Bytes(), 0600); err != nil {
if err := os.WriteFile(keypath, keyPEM, 0600); err != nil {
return err
}
if capath != "" {
if err := os.WriteFile(capath, caPEM.Bytes(), 0600); err != nil {
if err := os.WriteFile(capath, caPEM, 0600); err != nil {
return err
}
}
Expand All @@ -99,7 +99,7 @@ func AutoTLS(logger *zap.Logger, scfg *config.TLSConfig, autoca bool, workdir, m
return nil
}

func CreateTempTLS(rsaKeySize int, expiration time.Duration) (*bytes.Buffer, *bytes.Buffer, *bytes.Buffer, error) {
func CreateTempTLS(rsaKeySize int, expiration time.Duration) ([]byte, []byte, []byte, error) {
if rsaKeySize < 1024 {
rsaKeySize = 1024
}
Expand Down Expand Up @@ -189,7 +189,7 @@ func CreateTempTLS(rsaKeySize int, expiration time.Duration) (*bytes.Buffer, *by
return nil, nil, nil, err
}

return certPEM, keyPEM, caPEM, nil
return certPEM.Bytes(), keyPEM.Bytes(), caPEM.Bytes(), nil
}

// CreateTLSConfigForTest is from https://gist.github.com/shaneutt/5e1995295cff6721c89a71d13a71c251.
Expand All @@ -200,7 +200,7 @@ func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Con
return
}

serverCert, uerr := tls.X509KeyPair(certPEM.Bytes(), keyPEM.Bytes())
serverCert, uerr := tls.X509KeyPair(certPEM, keyPEM)
if uerr != nil {
err = uerr
return
Expand All @@ -212,7 +212,7 @@ func CreateTLSConfigForTest() (serverTLSConf *tls.Config, clientTLSConf *tls.Con
}

certpool := x509.NewCertPool()
certpool.AppendCertsFromPEM(caPEM.Bytes())
certpool.AppendCertsFromPEM(caPEM)
clientTLSConf = &tls.Config{
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: true,
Expand Down
Loading

0 comments on commit 34af7ee

Please sign in to comment.