Skip to content

Commit

Permalink
CertificateObject: Enable auto-reload for cacert & Add buildChain & F…
Browse files Browse the repository at this point in the history
…ixes (XTLS#3607)
  • Loading branch information
lelemka0 authored and leninalive committed Oct 29, 2024
1 parent da0cbe8 commit a41eab2
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 122 deletions.
15 changes: 11 additions & 4 deletions common/protocol/tls/cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,6 @@ func Generate(parent *Certificate, opts ...Option) (*Certificate, error) {
BasicConstraintsValid: true,
}

for _, opt := range opts {
opt(template)
}

parentCert := template
if parent != nil {
pCert, err := x509.ParseCertificate(parent.Certificate)
Expand All @@ -162,6 +158,17 @@ func Generate(parent *Certificate, opts ...Option) (*Certificate, error) {
parentCert = pCert
}

if parentCert.NotAfter.Before(template.NotAfter) {
template.NotAfter = parentCert.NotAfter
}
if parentCert.NotBefore.After(template.NotBefore) {
template.NotBefore = parentCert.NotBefore
}

for _, opt := range opts {
opt(template)
}

derBytes, err := x509.CreateCertificate(rand.Reader, template, parentCert, publicKey(selfKey), parentKey)
if err != nil {
return nil, errors.New("failed to create certificate").Base(err)
Expand Down
2 changes: 2 additions & 0 deletions infra/conf/transport_internet.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ type TLSCertConfig struct {
Usage string `json:"usage"`
OcspStapling uint64 `json:"ocspStapling"`
OneTimeLoading bool `json:"oneTimeLoading"`
BuildChain bool `json:"buildChain"`
}

// Build implements Buildable.
Expand Down Expand Up @@ -423,6 +424,7 @@ func (c *TLSCertConfig) Build() (*tls.Certificate, error) {
certificate.OneTimeLoading = c.OneTimeLoading
}
certificate.OcspStapling = c.OcspStapling
certificate.BuildChain = c.BuildChain

return certificate, nil
}
Expand Down
131 changes: 74 additions & 57 deletions transport/internet/tls/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"
"sync"
"time"
"bytes"

"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
"github.com/amnezia-vpn/amnezia-xray-core/common/net"
Expand Down Expand Up @@ -50,72 +51,84 @@ func (c *Config) BuildCertificates() []*tls.Certificate {
if entry.Usage != Certificate_ENCIPHERMENT {
continue
}
keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key)
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid X509 key pair")
continue
getX509KeyPair := func() *tls.Certificate {
keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key)
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid X509 key pair")
return nil
}
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate")
return nil
}
return &keyPair
}
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid certificate")
if keyPair := getX509KeyPair(); keyPair != nil {
certs = append(certs, keyPair)
} else {
continue
}
certs = append(certs, &keyPair)
if !entry.OneTimeLoading {
var isOcspstapling bool
hotReloadCertInterval := uint64(3600)
if entry.OcspStapling != 0 {
hotReloadCertInterval = entry.OcspStapling
isOcspstapling = true
index := len(certs) - 1
setupOcspTicker(entry, func(isReloaded, isOcspstapling bool){
cert := certs[index]
if isReloaded {
if newKeyPair := getX509KeyPair(); newKeyPair != nil {
cert = newKeyPair
} else {
return
}
}
index := len(certs) - 1
go func(entry *Certificate, cert *tls.Certificate, index int) {
t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second)
for {
if entry.CertificatePath != "" && entry.KeyPath != "" {
newCert, err := filesystem.ReadFile(entry.CertificatePath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse certificate")
<-t.C
continue
}
newKey, err := filesystem.ReadFile(entry.KeyPath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse key")
<-t.C
continue
}
if string(newCert) != string(entry.Certificate) && string(newKey) != string(entry.Key) {
newKeyPair, err := tls.X509KeyPair(newCert, newKey)
if err != nil {
errors.LogErrorInner(context.Background(), err, "ignoring invalid X509 key pair")
<-t.C
continue
}
if newKeyPair.Leaf, err = x509.ParseCertificate(newKeyPair.Certificate[0]); err != nil {
errors.LogErrorInner(context.Background(), err, "ignoring invalid certificate")
<-t.C
continue
}
cert = &newKeyPair
}
}
if isOcspstapling {
if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP")
} else if string(newOCSPData) != string(cert.OCSPStaple) {
cert.OCSPStaple = newOCSPData
}
}
certs[index] = cert
<-t.C
if isOcspstapling {
if newOCSPData, err := ocsp.GetOCSPForCert(cert.Certificate); err != nil {
errors.LogWarningInner(context.Background(), err, "ignoring invalid OCSP")
} else if string(newOCSPData) != string(cert.OCSPStaple) {
cert.OCSPStaple = newOCSPData
}
}(entry, certs[index], index)
}
}
certs[index] = cert
})
}
return certs
}

func setupOcspTicker(entry *Certificate, callback func(isReloaded, isOcspstapling bool)) {
go func() {
if entry.OneTimeLoading {
return
}
var isOcspstapling bool
hotReloadCertInterval := uint64(3600)
if entry.OcspStapling != 0 {
hotReloadCertInterval = entry.OcspStapling
isOcspstapling = true
}
t := time.NewTicker(time.Duration(hotReloadCertInterval) * time.Second)
for {
var isReloaded bool
if entry.CertificatePath != "" && entry.KeyPath != "" {
newCert, err := filesystem.ReadFile(entry.CertificatePath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse certificate")
return
}
newKey, err := filesystem.ReadFile(entry.KeyPath)
if err != nil {
errors.LogErrorInner(context.Background(), err, "failed to parse key")
return
}
if string(newCert) != string(entry.Certificate) || string(newKey) != string(entry.Key) {
entry.Certificate = newCert
entry.Key = newKey
isReloaded = true
}
}
callback(isReloaded, isOcspstapling)
<-t.C
}
}()
}

func isCertificateExpired(c *tls.Certificate) bool {
if c.Leaf == nil && len(c.Certificate) > 0 {
if pc, err := x509.ParseCertificate(c.Certificate[0]); err == nil {
Expand All @@ -137,6 +150,9 @@ func issueCertificate(rawCA *Certificate, domain string) (*tls.Certificate, erro
return nil, errors.New("failed to generate new certificate for ", domain).Base(err)
}
newCertPEM, newKeyPEM := newCert.ToPEM()
if rawCA.BuildChain {
newCertPEM = bytes.Join([][]byte{newCertPEM, rawCA.Certificate}, []byte("\n"))
}
cert, err := tls.X509KeyPair(newCertPEM, newKeyPEM)
return &cert, err
}
Expand All @@ -146,6 +162,7 @@ func (c *Config) getCustomCA() []*Certificate {
for _, certificate := range c.Certificate {
if certificate.Usage == Certificate_AUTHORITY_ISSUE {
certs = append(certs, certificate)
setupOcspTicker(certificate, func(isReloaded, isOcspstapling bool){ })
}
}
return certs
Expand Down
Loading

0 comments on commit a41eab2

Please sign in to comment.