Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

为CA证书启用自动重载 #3607

Merged
merged 5 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions common/protocol/tls/cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,6 @@ func Generate(parent *Certificate, opts ...Option) (*Certificate, error) {
BasicConstraintsValid: true,
}

for _, opt := range opts {
opt(template)
}

parentCert := template
if parent != nil {
pCert, err := x509.ParseCertificate(parent.Certificate)
Expand All @@ -162,6 +158,17 @@ func Generate(parent *Certificate, opts ...Option) (*Certificate, error) {
parentCert = pCert
}

if parentCert.NotAfter.Before(template.NotAfter) {
template.NotAfter = parentCert.NotAfter
}
if parentCert.NotBefore.After(template.NotBefore) {
template.NotBefore = parentCert.NotBefore
}

for _, opt := range opts {
opt(template)
}

derBytes, err := x509.CreateCertificate(rand.Reader, template, parentCert, publicKey(selfKey), parentKey)
if err != nil {
return nil, errors.New("failed to create certificate").Base(err)
Expand Down
2 changes: 2 additions & 0 deletions infra/conf/transport_internet.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ type TLSCertConfig struct {
Usage string `json:"usage"`
OcspStapling uint64 `json:"ocspStapling"`
OneTimeLoading bool `json:"oneTimeLoading"`
BuildChain bool `json:"buildChain"`
}

// Build implements Buildable.
Expand Down Expand Up @@ -415,6 +416,7 @@ func (c *TLSCertConfig) Build() (*tls.Certificate, error) {
certificate.OneTimeLoading = c.OneTimeLoading
}
certificate.OcspStapling = c.OcspStapling
certificate.BuildChain = c.BuildChain

return certificate, nil
}
Expand Down
131 changes: 74 additions & 57 deletions transport/internet/tls/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"
"sync"
"time"
"bytes"

"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
Expand Down Expand Up @@ -50,72 +51,84 @@ func (c *Config) BuildCertificates() []*tls.Certificate {
if entry.Usage != Certificate_ENCIPHERMENT {
continue
}
keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key)
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid X509 key pair")
continue
getX509KeyPair := func() *tls.Certificate {
keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key)
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid X509 key pair")
return nil
}
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate")
return nil
}
return &keyPair
}
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate")
if keyPair := getX509KeyPair(); keyPair != nil {
certs = append(certs, keyPair)
} else {
continue
}
certs = append(certs, &keyPair)
if !entry.OneTimeLoading {
var isOcspstapling bool
hotReloadCertInterval := uint64(3600)
if entry.OcspStapling != 0 {
hotReloadCertInterval = entry.OcspStapling
isOcspstapling = true
index := len(certs) - 1
setupOcspTicker(entry, func(isReloaded, isOcspstapling bool){
cert := certs[index]
if isReloaded {
if newKeyPair := getX509KeyPair(); newKeyPair != nil {
cert = newKeyPair
} else {
return
}
}
index := len(certs) - 1
go func(entry *Certificate, cert *tls.Certificate, index int) {
t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second)
for {
if entry.CertificatePath != "" && entry.KeyPath != "" {
newCert, err := filesystem.ReadFile(entry.CertificatePath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse certificate")
<-t.C
continue
}
newKey, err := filesystem.ReadFile(entry.KeyPath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse key")
<-t.C
continue
}
if string(newCert) != string(entry.Certificate) && string(newKey) != string(entry.Key) {
newKeyPair, err := tls.X509KeyPair(newCert, newKey)
if err != nil {
errors.LogErrorInner(context.Background(), err, "ignoring invalid X509 key pair")
<-t.C
continue
}
if newKeyPair.Leaf, err = x509.ParseCertificate(newKeyPair.Certificate[0]); err != nil {
errors.LogErrorInner(context.Background(), err, "ignoring invalid certificate")
<-t.C
continue
}
cert = &newKeyPair
}
}
if isOcspstapling {
if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP")
} else if string(newOCSPData) != string(cert.OCSPStaple) {
cert.OCSPStaple = newOCSPData
}
}
certs[index] = cert
<-t.C
if isOcspstapling {
if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP")
} else if string(newOCSPData) != string(cert.OCSPStaple) {
cert.OCSPStaple = newOCSPData
}
}(entry, certs[index], index)
}
}
certs[index] = cert
})
}
return certs
}

func setupOcspTicker(entry *Certificate, callback func(isReloaded, isOcspstapling bool)) {
go func() {
if entry.OneTimeLoading {
return
}
var isOcspstapling bool
hotReloadCertInterval := uint64(3600)
if entry.OcspStapling != 0 {
hotReloadCertInterval = entry.OcspStapling
isOcspstapling = true
}
t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second)
for {
var isReloaded bool
if entry.CertificatePath != "" && entry.KeyPath != "" {
newCert, err := filesystem.ReadFile(entry.CertificatePath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse certificate")
return
}
newKey, err := filesystem.ReadFile(entry.KeyPath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse key")
return
}
if string(newCert) != string(entry.Certificate) || string(newKey) != string(entry.Key) {
entry.Certificate = newCert
entry.Key = newKey
isReloaded = true
}
}
callback(isReloaded, isOcspstapling)
<-t.C
}
}()
}

func isCertificateExpired(c *tls.Certificate) bool {
if c.Leaf == nil && len(c.Certificate) > 0 {
if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil {
Expand All @@ -137,6 +150,9 @@ func issueCertificate(rawCA *Certificate, domain string) (*tls.Certificate, erro
return nil, errors.New("failed to generate new certificate for ", domain).Base(err)
}
newCertPEM, newKeyPEM := newCert.ToPEM()
if rawCA.BuildChain {
newCertPEM = bytes.Join([][]byte{newCertPEM, rawCA.Certificate}, []byte("\n"))
}
cert, err := tls.X509KeyPair(newCertPEM, newKeyPEM)
return &cert, err
}
Expand All @@ -146,6 +162,7 @@ func (c *Config) getCustomCA() []*Certificate {
for _, certificate := range c.Certificate {
if certificate.Usage == Certificate_AUTHORITY_ISSUE {
certs = append(certs, certificate)
setupOcspTicker(certificate, func(isReloaded, isOcspstapling bool){ })
}
}
return certs
Expand Down
Loading