Skip to content

Commit

Permalink
optimize expiry periods of certificates (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Jul 16, 2022
1 parent ff5aa30 commit 2823159
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 67 deletions.
106 changes: 58 additions & 48 deletions p2p/transport/webtransport/cert_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,19 @@ import (
"github.com/multiformats/go-multihash"
)

// Allow for a bit of clock skew.
// When we generate a certificate, the NotBefore time is set to clockSkewAllowance before the current time.
// Similarly, we stop using a certificate one clockSkewAllowance before its expiry time.
const clockSkewAllowance = time.Hour

type certConfig struct {
tlsConf *tls.Config
sha256 [32]byte // cached from the tlsConf
}

func (c *certConfig) Start() time.Time { return c.tlsConf.Certificates[0].Leaf.NotBefore }
func (c *certConfig) End() time.Time { return c.tlsConf.Certificates[0].Leaf.NotAfter }

func newCertConfig(start, end time.Time) (*certConfig, error) {
conf, err := getTLSConf(start, end)
if err != nil {
Expand All @@ -32,22 +40,17 @@ func newCertConfig(start, end time.Time) (*certConfig, error) {
}

// Certificate renewal logic:
// 0. To simplify the math, assume the certificate is valid for 10 days (in real life: 14 days).
// 1. On startup, we generate the first certificate (1).
// 2. After 4 days, we generate a second certificate (2).
// We don't use that certificate yet, but we advertise the hashes of (1) and (2).
// That allows clients to connect to us using addresses that are 4 days old.
// 3. After another 4 days, we now actually start using (2).
// We also generate a third certificate (3), and start advertising the hashes of (2) and (3).
// We continue to remember the hash of (1) for validation during the Noise handshake for another 4 days,
// as the client might be connecting with a cached address.
// 1. On startup, we generate one cert that is valid from now (-1h, to allow for clock skew), and another
// cert that is valid from the expiry date of the first certificate (again, with allowance for clock skew).
// 2. Once we reach 1h before expiry of the first certificate, we switch over to the second certificate.
// At the same time, we stop advertising the certhash of the first cert and generate the next cert.
type certManager struct {
clock clock.Clock
ctx context.Context
ctxCancel context.CancelFunc
refCount sync.WaitGroup

mx sync.Mutex
mx sync.RWMutex
lastConfig *certConfig // initially nil
currentConfig *certConfig
nextConfig *certConfig // nil until we have passed half the certValidity of the current config
Expand All @@ -61,64 +64,71 @@ func newCertManager(clock clock.Clock) (*certManager, error) {
return nil, err
}

t := m.clock.Ticker(certValidity * 4 / 9) // make sure we're a bit faster than 1/2
m.refCount.Add(1)
go func() {
defer m.refCount.Done()
defer t.Stop()
if err := m.background(t); err != nil {
log.Fatal(err)
}
}()
m.background()
return m, nil
}

func (m *certManager) init() error {
start := m.clock.Now()
end := start.Add(certValidity)
cc, err := newCertConfig(start, end)
start := m.clock.Now().Add(-clockSkewAllowance)
var err error
m.nextConfig, err = newCertConfig(start, start.Add(certValidity))
if err != nil {
return err
}
m.currentConfig = cc
return m.rollConfig()
}

func (m *certManager) rollConfig() error {
// We stop using the current certificate clockSkewAllowance before its expiry time.
// At this point, the next certificate needs to be valid for one clockSkewAllowance.
nextStart := m.nextConfig.End().Add(-2 * clockSkewAllowance)
c, err := newCertConfig(nextStart, nextStart.Add(certValidity))
if err != nil {
return err
}
m.lastConfig = m.currentConfig
m.currentConfig = m.nextConfig
m.nextConfig = c
return m.cacheAddrComponent()
}

func (m *certManager) background(t *clock.Ticker) error {
for {
select {
case <-m.ctx.Done():
return nil
case start := <-t.C:
end := start.Add(certValidity)
cc, err := newCertConfig(start, end)
if err != nil {
return err
}
m.mx.Lock()
if m.nextConfig != nil {
m.lastConfig = m.currentConfig
m.currentConfig = m.nextConfig
}
m.nextConfig = cc
if err := m.cacheAddrComponent(); err != nil {
func (m *certManager) background() {
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(m.clock.Now())
log.Debugw("setting timer", "duration", d.String())
t := m.clock.Timer(d)
m.refCount.Add(1)

go func() {
defer m.refCount.Done()
defer t.Stop()

for {
select {
case <-m.ctx.Done():
return
case now := <-t.C:
m.mx.Lock()
if err := m.rollConfig(); err != nil {
log.Errorw("rolling config failed", "error", err)
}
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(now)
log.Debugw("rolling certificates", "next", d.String())
t.Reset(d)
m.mx.Unlock()
return err
}
m.mx.Unlock()
}
}
}()
}

func (m *certManager) GetConfig() *tls.Config {
m.mx.Lock()
defer m.mx.Unlock()
m.mx.RLock()
defer m.mx.RUnlock()
return m.currentConfig.tlsConf
}

func (m *certManager) AddrComponent() ma.Multiaddr {
m.mx.Lock()
defer m.mx.Unlock()
m.mx.RLock()
defer m.mx.RUnlock()
return m.addrComp
}

Expand Down
53 changes: 36 additions & 17 deletions p2p/transport/webtransport/cert_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@ func TestInitialCert(t *testing.T) {
conf := m.GetConfig()
require.Len(t, conf.Certificates, 1)
cert := conf.Certificates[0]
require.Equal(t, cl.Now().UTC(), cert.Leaf.NotBefore)
require.Equal(t, cl.Now().Add(certValidity).UTC(), cert.Leaf.NotAfter)
require.Equal(t, cl.Now().Add(-clockSkewAllowance).UTC(), cert.Leaf.NotBefore)
require.Equal(t, cert.Leaf.NotBefore.Add(certValidity), cert.Leaf.NotAfter)
addr := m.AddrComponent()
components := splitMultiaddr(addr)
require.Len(t, components, 1)
require.Len(t, components, 2)
require.Equal(t, ma.P_CERTHASH, components[0].Protocol().Code)
hash := certificateHashFromTLSConfig(conf)
require.Equal(t, hash[:], certHashFromComponent(t, components[0]))
require.Equal(t, ma.P_CERTHASH, components[1].Protocol().Code)
}

func TestCertRenewal(t *testing.T) {
Expand All @@ -63,21 +64,39 @@ func TestCertRenewal(t *testing.T) {
defer m.Close()

firstConf := m.GetConfig()
require.Len(t, splitMultiaddr(m.AddrComponent()), 1)
first := splitMultiaddr(m.AddrComponent())
require.Len(t, first, 2)
require.NotEqual(t, first[0].Value(), first[1].Value(), "the hashes should differ")
// wait for a new certificate to be generated
cl.Add(certValidity / 2)
require.Eventually(t, func() bool { return len(splitMultiaddr(m.AddrComponent())) > 1 }, 200*time.Millisecond, 10*time.Millisecond)
// the actual config used should still be the same, we're just advertising the hash of the next config
components := splitMultiaddr(m.AddrComponent())
require.Len(t, components, 2)
for _, c := range components {
cl.Add(certValidity - 2*clockSkewAllowance - time.Second)
require.Never(t, func() bool {
for i, c := range splitMultiaddr(m.AddrComponent()) {
if c.Value() != first[i].Value() {
return true
}
}
return false
}, 100*time.Millisecond, 10*time.Millisecond)
cl.Add(2 * time.Second)
require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond)
secondConf := m.GetConfig()

second := splitMultiaddr(m.AddrComponent())
require.Len(t, second, 2)
for _, c := range second {
require.Equal(t, ma.P_CERTHASH, c.Protocol().Code)
}
require.Equal(t, firstConf, m.GetConfig())
cl.Add(certValidity / 2)
require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond)
newConf := m.GetConfig()
// check that the new config now matches the second component
hash := certificateHashFromTLSConfig(newConf)
require.Equal(t, hash[:], certHashFromComponent(t, components[1]))
// check that the 2nd certificate from the beginning was rolled over to be the 1st certificate
require.Equal(t, first[1].Value(), second[0].Value())
require.NotEqual(t, first[0].Value(), second[1].Value())

cl.Add(certValidity - 2*clockSkewAllowance + time.Second)
require.Eventually(t, func() bool { return m.GetConfig() != secondConf }, 200*time.Millisecond, 10*time.Millisecond)
third := splitMultiaddr(m.AddrComponent())
require.Len(t, third, 2)
for _, c := range third {
require.Equal(t, ma.P_CERTHASH, c.Protocol().Code)
}
// check that the 2nd certificate from the beginning was rolled over to be the 1st certificate
require.Equal(t, second[1].Value(), third[0].Value())
}
13 changes: 11 additions & 2 deletions p2p/transport/webtransport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,16 @@ func TestHashVerification(t *testing.T) {

t.Run("fails using only a wrong hash", func(t *testing.T) {
// replace the certificate hash in the multiaddr with a fake hash
addr, _ := ma.SplitLast(ln.Multiaddr())
addr := ln.Multiaddr()
// strip off all certhash components
for {
a, comp := ma.SplitLast(addr)
if comp.Protocol().Code != ma.P_CERTHASH {
break
}
addr = a
}

addr = addr.Encapsulate(foobarHash)

_, err := tr2.Dial(context.Background(), addr, serverID)
Expand Down Expand Up @@ -224,7 +233,7 @@ func TestListenerAddrs(t *testing.T) {
ln2, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
hashes1 := extractCertHashes(ln1.Multiaddr())
require.Len(t, hashes1, 1)
require.Len(t, hashes1, 2)
hashes2 := extractCertHashes(ln2.Multiaddr())
require.Equal(t, hashes1, hashes2)
}
Expand Down

0 comments on commit 2823159

Please sign in to comment.