Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamically refresh tls certs for all servers #3598

Merged
merged 2 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions internal/certloader/certloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Package certloader contains a certicate loader.
package certloader

import (
"crypto/tls"
"sync"

"github.com/bluenviron/mediamtx/internal/confwatcher"
"github.com/bluenviron/mediamtx/internal/logger"
)

// CertLoader is a certificate loader. It watches for changes to the certificate and key files.
type CertLoader struct {
log logger.Writer
certWatcher, keyWatcher *confwatcher.ConfWatcher
certPath, keyPath string
done chan struct{}

cert *tls.Certificate
certMu sync.RWMutex
}

// New allocates a CertLoader.
func New(certPath, keyPath string, log logger.Writer) (*CertLoader, error) {
cl := &CertLoader{
log: log,
certPath: certPath,
keyPath: keyPath,
done: make(chan struct{}),
}

var err error
cl.certWatcher, err = confwatcher.New(certPath)
if err != nil {
return nil, err
}

cl.keyWatcher, err = confwatcher.New(keyPath)
if err != nil {
cl.certWatcher.Close() //nolint:errcheck
return nil, err
}

cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, err
}

cl.certMu.Lock()
cl.cert = &cert
cl.certMu.Unlock()

go cl.watch()

return cl, nil
}

// Close closes a CertLoader and releases any underlying resources.
func (cl *CertLoader) Close() {
close(cl.done)
cl.certWatcher.Close() //nolint:errcheck
cl.keyWatcher.Close() //nolint:errcheck
cl.certMu.Lock()
defer cl.certMu.Unlock()
cl.cert = nil
}

// GetCertificate returns a function that returns the certificate for use in a tls.Config.
func (cl *CertLoader) GetCertificate() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
cl.certMu.RLock()
defer cl.certMu.RUnlock()
return cl.cert, nil
}
}

func (cl *CertLoader) watch() {
for {
select {
case <-cl.certWatcher.Watch():
cert, err := tls.LoadX509KeyPair(cl.certPath, cl.keyPath)
if err != nil {
cl.log.Log(logger.Error, "certloader failed to load after change to %s: %s", cl.certPath, err.Error())
continue
}

cl.certMu.Lock()
cl.cert = &cert
cl.certMu.Unlock()

cl.log.Log(logger.Info, "certificate reloaded after change to %s", cl.certPath)
case <-cl.keyWatcher.Watch():
cert, err := tls.LoadX509KeyPair(cl.certPath, cl.keyPath)
if err != nil {
cl.log.Log(logger.Error, "certloader failed to load after change to %s: %s", cl.keyPath, err.Error())
continue
}

cl.certMu.Lock()
cl.cert = &cert
cl.certMu.Unlock()

cl.log.Log(logger.Info, "certificate reloaded after change to %s", cl.keyPath)
case <-cl.done:
return
}
}
}
52 changes: 52 additions & 0 deletions internal/certloader/certloader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package certloader

import (
"crypto/tls"
"os"
"testing"
"time"

"github.com/bluenviron/mediamtx/internal/test"
"github.com/stretchr/testify/require"
)

func TestCertReload(t *testing.T) {
testData, err := tls.X509KeyPair(test.TLSCertPub, test.TLSCertKey)
require.NoError(t, err)

serverCertPath, err := test.CreateTempFile(test.TLSCertPub)
require.NoError(t, err)
defer os.Remove(serverCertPath)

serverKeyPath, err := test.CreateTempFile(test.TLSCertKey)
require.NoError(t, err)
defer os.Remove(serverKeyPath)

loader, err := New(serverCertPath, serverKeyPath, test.NilLogger)
require.NoError(t, err)
defer loader.Close()

getCert := loader.GetCertificate()
require.NotNil(t, getCert)

cert, err := getCert(nil)
require.NoError(t, err)
require.NotNil(t, cert)
require.Equal(t, &testData, cert)

testData, err = tls.X509KeyPair(test.TLSCertPubAlt, test.TLSCertKeyAlt)
require.NoError(t, err)

err = os.WriteFile(serverCertPath, test.TLSCertPubAlt, 0o644)
require.NoError(t, err)

err = os.WriteFile(serverKeyPath, test.TLSCertKeyAlt, 0o644)
require.NoError(t, err)

time.Sleep(1 * time.Second)

cert, err = getCert(nil)
require.NoError(t, err)
require.NotNil(t, cert)
require.Equal(t, &testData, cert)
}
15 changes: 11 additions & 4 deletions internal/protocols/httpp/wrapped_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/http"
"time"

"github.com/bluenviron/mediamtx/internal/certloader"
"github.com/bluenviron/mediamtx/internal/logger"
)

Expand All @@ -36,8 +37,9 @@ type WrappedServer struct {
Handler http.Handler
Parent logger.Writer

ln net.Listener
inner *http.Server
ln net.Listener
inner *http.Server
loader *certloader.CertLoader
}

// Initialize initializes a WrappedServer.
Expand All @@ -47,13 +49,15 @@ func (s *WrappedServer) Initialize() error {
if s.ServerCert == "" {
return fmt.Errorf("server cert is missing")
}
crt, err := tls.LoadX509KeyPair(s.ServerCert, s.ServerKey)

var err error
s.loader, err = certloader.New(s.ServerCert, s.ServerKey, s.Parent)
if err != nil {
return err
}

tlsConfig = &tls.Config{
Certificates: []tls.Certificate{crt},
GetCertificate: s.loader.GetCertificate(),
}
}

Expand Down Expand Up @@ -92,4 +96,7 @@ func (s *WrappedServer) Close() {
ctxCancel()
s.inner.Shutdown(ctx)
s.ln.Close() // in case Shutdown() is called before Serve()
if s.loader != nil {
s.loader.Close()
}
}
10 changes: 8 additions & 2 deletions internal/servers/rtmp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/google/uuid"

"github.com/bluenviron/mediamtx/internal/certloader"
"github.com/bluenviron/mediamtx/internal/conf"
"github.com/bluenviron/mediamtx/internal/defs"
"github.com/bluenviron/mediamtx/internal/externalcmd"
Expand Down Expand Up @@ -82,6 +83,7 @@ type Server struct {
wg sync.WaitGroup
ln net.Listener
conns map[*conn]struct{}
loader *certloader.CertLoader

// in
chNewConn chan net.Conn
Expand All @@ -99,13 +101,14 @@ func (s *Server) Initialize() error {
return net.Listen(restrictnetwork.Restrict("tcp", s.Address))
}

cert, err := tls.LoadX509KeyPair(s.ServerCert, s.ServerKey)
var err error
s.loader, err = certloader.New(s.ServerCert, s.ServerKey, s.Parent)
if err != nil {
return nil, err
}

network, address := restrictnetwork.Restrict("tcp", s.Address)
return tls.Listen(network, address, &tls.Config{Certificates: []tls.Certificate{cert}})
return tls.Listen(network, address, &tls.Config{GetCertificate: s.loader.GetCertificate()})
}()
if err != nil {
return err
Expand Down Expand Up @@ -153,6 +156,9 @@ func (s *Server) Close() {
s.Log(logger.Info, "listener is closing")
s.ctxCancel()
s.wg.Wait()
if s.loader != nil {
s.loader.Close()
}
}

func (s *Server) run() {
Expand Down
10 changes: 8 additions & 2 deletions internal/servers/rtsp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
"github.com/google/uuid"

"github.com/bluenviron/mediamtx/internal/certloader"
"github.com/bluenviron/mediamtx/internal/conf"
"github.com/bluenviron/mediamtx/internal/defs"
"github.com/bluenviron/mediamtx/internal/externalcmd"
Expand Down Expand Up @@ -89,6 +90,7 @@ type Server struct {
mutex sync.RWMutex
conns map[*gortsplib.ServerConn]*conn
sessions map[*gortsplib.ServerSession]*session
loader *certloader.CertLoader
}

// Initialize initializes the server.
Expand Down Expand Up @@ -118,12 +120,13 @@ func (s *Server) Initialize() error {
}

if s.IsTLS {
cert, err := tls.LoadX509KeyPair(s.ServerCert, s.ServerKey)
var err error
s.loader, err = certloader.New(s.ServerCert, s.ServerKey, s.Parent)
if err != nil {
return err
}

s.srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
s.srv.TLSConfig = &tls.Config{GetCertificate: s.loader.GetCertificate()}
}

err := s.srv.Start()
Expand Down Expand Up @@ -155,6 +158,9 @@ func (s *Server) Close() {
s.Log(logger.Info, "listener is closing")
s.ctxCancel()
s.wg.Wait()
if s.loader != nil {
s.loader.Close()
}
}

func (s *Server) run() {
Expand Down
52 changes: 52 additions & 0 deletions internal/test/tls_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,55 @@ y++U32uuSFiXDcSLarfIsE992MEJLSAynbF1Rsgsr3gXbGiuToJRyxbIeVy7gwzD
+3K6cnKEyg+0ekYmLertRFIY6SwWmY1fyKgTvxudMcsBY7dC4xs=
-----END RSA PRIVATE KEY-----
`)

// TLSCertPubAlt is the public key of an alternative test certificate.
var TLSCertPubAlt = []byte(`-----BEGIN CERTIFICATE-----
MIIDSTCCAjECFEut6ZxIOnbxi3bhrPLfPQZCLReNMA0GCSqGSIb3DQEBCwUAMGEx
CzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl
cm5ldCBXaWRnaXRzIFB0eSBMdGQxGjAYBgNVBAMMEW1lZGlhbXR4LnRlc3QuY29t
MB4XDTI0MDgwMTIzNDY0MloXDTM0MDczMDIzNDY0MlowYTELMAkGA1UEBhMCQVUx
EzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMg
UHR5IEx0ZDEaMBgGA1UEAwwRbWVkaWFtdHgudGVzdC5jb20wggEiMA0GCSqGSIb3
DQEBAQUAA4IBDwAwggEKAoIBAQCzfvG9eLXKSTDBoM+cgV/ThiNRI2JY6dpQV8rK
QFQ5bkkDUDP+2Ae/IWylgLLXmozsMwjz1Pu42awmGymBuo5HDbI4bxPJNQR9qRrR
2+MvfDgmZxyhw5NfZDlVl+enxhb3FRgbHsLBy4oSoHbRUdLApVdM0Kg6r3bXzkih
EEs63boFJOkPhs5H0NX7AzXyBp2WnvB71j+7avnMwAsjJHOiTs8wkp5wvRcIZpJl
MCandUkcZShMirug7QOcR9fAr5CVKxsO/DjqEjwkslJHFfizOl3yRx6nsxvW8JUd
dforpSRj84dkHTi7k37YTiji90GsOvh0qc0MfAmeE181HIb/AgMBAAEwDQYJKoZI
hvcNAQELBQADggEBAEWkLL/7nvt3iD7BVJNHLvAS6GwuTH99vCil6TFYwVl4goht
Dur7YfzN43vUq+lAwS3Ry4ka7tH72pAMkpNFRvHOikWGmWUSDo2DcLd8iu3ruLF7
yUg2ASQuekK0sUv4YKpAqV8gS2R4Jh4vLU+8L5iJ1XWGELbQ+H5wm4l7l+r2X6cD
/opmdV8Slfi0FlNQtflLsGoSlfZF5jHxqi3zyt8QdEf9WZt8e6JPxcx2Fq7Op51u
Qx9nosr5fLwhkx46+B/cotsbI/xPDjLF6RQ1OUpcHwg1HI6czoW4hHn33S0zstCf
BWt5Q1Mb2tGInbmbUgw3wUu/4nWoY+Mq4DKPlKs=
-----END CERTIFICATE-----`)

// TLSCertKeyAlt is the private key of an alternative test certificate.
var TLSCertKeyAlt = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEoQIBAAKCAQEAs37xvXi1ykkwwaDPnIFf04YjUSNiWOnaUFfKykBUOW5JA1Az
/tgHvyFspYCy15qM7DMI89T7uNmsJhspgbqORw2yOG8TyTUEfaka0dvjL3w4Jmcc
ocOTX2Q5VZfnp8YW9xUYGx7CwcuKEqB20VHSwKVXTNCoOq92185IoRBLOt26BSTp
D4bOR9DV+wM18gadlp7we9Y/u2r5zMALIyRzok7PMJKecL0XCGaSZTAmp3VJHGUo
TIq7oO0DnEfXwK+QlSsbDvw46hI8JLJSRxX4szpd8kcep7Mb1vCVHXX6K6UkY/OH
ZB04u5N+2E4o4vdBrDr4dKnNDHwJnhNfNRyG/wIDAQABAoH/WmCqV6Lv5dEnofCj
ZUO/Fdv0hf/LBS0g2SAoFRSCIM8aJ3dUUH0PaXoeINDGCMlIxT7tKXJg5jJNYhWx
g7oegw6vLe5ZiA+p5miL/uue+Jas4kLVp9DrfQLgQevt0gw4g/00pgy9adbFlTUD
a2HhPB7RIvXs8gYA6nVAT9jK1ST2pbeUgQNO4Ji4EjpPUkR2O7ISOlu5EV8Cj0eV
1Vs5B92Z7ORh7P2fFV2YBu+igd04+uYvei6slQl+F9cETvJv2Z9r37Yashvnn1in
uy/u1U4B1t4oOz81nHz6kxTixPpBOdJ6x8jLDgNGSsauJQfXT9xmB/rAr/NFq+7I
tbTNAoGBAMOgm3XXHWokmJnX9pfNj6ixNlrMuuez/yXMVwuxa2WFwAFN16tjJhBi
XOjestcvu/SRhOAMmYac5QdopJpLjO/FxO165r73eZhW/SJefyOHtfD29kHagA1u
JjcznU6tiA0O1owy6nuuaTfyVbDQj32PhVBx9ZwSI4778GFbjWl7AoGBAOrj4WCC
gTMaExpwNo+L+3VkM79YD1Obl13FcgtVoxjcoWjQeMx9D0k7adTV3xlchHFAjiD5
Gs/MZl8+seq+GDX3mODsmJkdRQbYId4g6IesiOnQ3Ug/Y282WZRnpB5h/BMnrcCZ
VoohnATA7f96c7XtPUgZyROmh24T7UIVwVdNAoGAbeeGT276TI6g2RWWqXRIOFrP
EbYhb1kViFPDt4MGtjOtSk5EUzpRwTSxw/aRfQmJS/6RKxqJCjKNDVuB1lmJpY9z
coPwrOr1+lssvalfPkPZOLZWZWrvNBxlBfBOeUxOuh9S89MLH08+N7tC3yJc6wq9
uBM+DF+4cHUkeF3qFY8CgYBzS+IwBj82/0CLRLNzaKnIqKPB846qYoA9NhLRv3ps
VLgiA9qXvXdIYhKDt2toPoKAOMjLJJtljpZdgB/C8wZdTyjKlzgcSEK+pk6RgyPA
nQ8jfjNwKDU9vLbh4rGrfDtIh7yBAoN5ECBOMQlh0xCDJ21iO834iFCH1t4qBxW9
LQKBgQC36adC2Gu+FJRvx4Mkm73fLmVdFbP6Do7qNwyVVyaG80PDVrFQrlWm4Dt7
AO9IwzaS1Lx+qmU1Fj1WfCtXuQa5nc9AzZ36TmM6+pAn8AC7PdNqc0qSdefVrIjj
zRGhUPaJV3A+sfO+xedBsAFnqNuX9oODYVGbTjuc2OWC30MGaw==
-----END RSA PRIVATE KEY-----
`)
Loading