From 2823159a99d45518a94db48f9942fa54005f2797 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 16 Jul 2022 10:22:10 +0000 Subject: [PATCH] optimize expiry periods of certificates (#21) --- p2p/transport/webtransport/cert_manager.go | 106 ++++++++++-------- .../webtransport/cert_manager_test.go | 53 ++++++--- p2p/transport/webtransport/transport_test.go | 13 ++- 3 files changed, 105 insertions(+), 67 deletions(-) diff --git a/p2p/transport/webtransport/cert_manager.go b/p2p/transport/webtransport/cert_manager.go index 519a89d88e..c46a60b666 100644 --- a/p2p/transport/webtransport/cert_manager.go +++ b/p2p/transport/webtransport/cert_manager.go @@ -15,11 +15,19 @@ import ( "github.com/multiformats/go-multihash" ) +// Allow for a bit of clock skew. +// When we generate a certificate, the NotBefore time is set to clockSkewAllowance before the current time. +// Similarly, we stop using a certificate one clockSkewAllowance before its expiry time. +const clockSkewAllowance = time.Hour + type certConfig struct { tlsConf *tls.Config sha256 [32]byte // cached from the tlsConf } +func (c *certConfig) Start() time.Time { return c.tlsConf.Certificates[0].Leaf.NotBefore } +func (c *certConfig) End() time.Time { return c.tlsConf.Certificates[0].Leaf.NotAfter } + func newCertConfig(start, end time.Time) (*certConfig, error) { conf, err := getTLSConf(start, end) if err != nil { @@ -32,22 +40,17 @@ func newCertConfig(start, end time.Time) (*certConfig, error) { } // Certificate renewal logic: -// 0. To simplify the math, assume the certificate is valid for 10 days (in real life: 14 days). -// 1. On startup, we generate the first certificate (1). -// 2. After 4 days, we generate a second certificate (2). -// We don't use that certificate yet, but we advertise the hashes of (1) and (2). -// That allows clients to connect to us using addresses that are 4 days old. -// 3. After another 4 days, we now actually start using (2). -// We also generate a third certificate (3), and start advertising the hashes of (2) and (3). -// We continue to remember the hash of (1) for validation during the Noise handshake for another 4 days, -// as the client might be connecting with a cached address. +// 1. On startup, we generate one cert that is valid from now (-1h, to allow for clock skew), and another +// cert that is valid from the expiry date of the first certificate (again, with allowance for clock skew). +// 2. Once we reach 1h before expiry of the first certificate, we switch over to the second certificate. +// At the same time, we stop advertising the certhash of the first cert and generate the next cert. type certManager struct { clock clock.Clock ctx context.Context ctxCancel context.CancelFunc refCount sync.WaitGroup - mx sync.Mutex + mx sync.RWMutex lastConfig *certConfig // initially nil currentConfig *certConfig nextConfig *certConfig // nil until we have passed half the certValidity of the current config @@ -61,64 +64,71 @@ func newCertManager(clock clock.Clock) (*certManager, error) { return nil, err } - t := m.clock.Ticker(certValidity * 4 / 9) // make sure we're a bit faster than 1/2 - m.refCount.Add(1) - go func() { - defer m.refCount.Done() - defer t.Stop() - if err := m.background(t); err != nil { - log.Fatal(err) - } - }() + m.background() return m, nil } func (m *certManager) init() error { - start := m.clock.Now() - end := start.Add(certValidity) - cc, err := newCertConfig(start, end) + start := m.clock.Now().Add(-clockSkewAllowance) + var err error + m.nextConfig, err = newCertConfig(start, start.Add(certValidity)) if err != nil { return err } - m.currentConfig = cc + return m.rollConfig() +} + +func (m *certManager) rollConfig() error { + // We stop using the current certificate clockSkewAllowance before its expiry time. + // At this point, the next certificate needs to be valid for one clockSkewAllowance. + nextStart := m.nextConfig.End().Add(-2 * clockSkewAllowance) + c, err := newCertConfig(nextStart, nextStart.Add(certValidity)) + if err != nil { + return err + } + m.lastConfig = m.currentConfig + m.currentConfig = m.nextConfig + m.nextConfig = c return m.cacheAddrComponent() } -func (m *certManager) background(t *clock.Ticker) error { - for { - select { - case <-m.ctx.Done(): - return nil - case start := <-t.C: - end := start.Add(certValidity) - cc, err := newCertConfig(start, end) - if err != nil { - return err - } - m.mx.Lock() - if m.nextConfig != nil { - m.lastConfig = m.currentConfig - m.currentConfig = m.nextConfig - } - m.nextConfig = cc - if err := m.cacheAddrComponent(); err != nil { +func (m *certManager) background() { + d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(m.clock.Now()) + log.Debugw("setting timer", "duration", d.String()) + t := m.clock.Timer(d) + m.refCount.Add(1) + + go func() { + defer m.refCount.Done() + defer t.Stop() + + for { + select { + case <-m.ctx.Done(): + return + case now := <-t.C: + m.mx.Lock() + if err := m.rollConfig(); err != nil { + log.Errorw("rolling config failed", "error", err) + } + d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(now) + log.Debugw("rolling certificates", "next", d.String()) + t.Reset(d) m.mx.Unlock() - return err } - m.mx.Unlock() } - } + }() } func (m *certManager) GetConfig() *tls.Config { - m.mx.Lock() - defer m.mx.Unlock() + m.mx.RLock() + defer m.mx.RUnlock() return m.currentConfig.tlsConf } func (m *certManager) AddrComponent() ma.Multiaddr { - m.mx.Lock() - defer m.mx.Unlock() + m.mx.RLock() + defer m.mx.RUnlock() return m.addrComp } diff --git a/p2p/transport/webtransport/cert_manager_test.go b/p2p/transport/webtransport/cert_manager_test.go index 69cd7163da..3f2328fbb7 100644 --- a/p2p/transport/webtransport/cert_manager_test.go +++ b/p2p/transport/webtransport/cert_manager_test.go @@ -46,14 +46,15 @@ func TestInitialCert(t *testing.T) { conf := m.GetConfig() require.Len(t, conf.Certificates, 1) cert := conf.Certificates[0] - require.Equal(t, cl.Now().UTC(), cert.Leaf.NotBefore) - require.Equal(t, cl.Now().Add(certValidity).UTC(), cert.Leaf.NotAfter) + require.Equal(t, cl.Now().Add(-clockSkewAllowance).UTC(), cert.Leaf.NotBefore) + require.Equal(t, cert.Leaf.NotBefore.Add(certValidity), cert.Leaf.NotAfter) addr := m.AddrComponent() components := splitMultiaddr(addr) - require.Len(t, components, 1) + require.Len(t, components, 2) require.Equal(t, ma.P_CERTHASH, components[0].Protocol().Code) hash := certificateHashFromTLSConfig(conf) require.Equal(t, hash[:], certHashFromComponent(t, components[0])) + require.Equal(t, ma.P_CERTHASH, components[1].Protocol().Code) } func TestCertRenewal(t *testing.T) { @@ -63,21 +64,39 @@ func TestCertRenewal(t *testing.T) { defer m.Close() firstConf := m.GetConfig() - require.Len(t, splitMultiaddr(m.AddrComponent()), 1) + first := splitMultiaddr(m.AddrComponent()) + require.Len(t, first, 2) + require.NotEqual(t, first[0].Value(), first[1].Value(), "the hashes should differ") // wait for a new certificate to be generated - cl.Add(certValidity / 2) - require.Eventually(t, func() bool { return len(splitMultiaddr(m.AddrComponent())) > 1 }, 200*time.Millisecond, 10*time.Millisecond) - // the actual config used should still be the same, we're just advertising the hash of the next config - components := splitMultiaddr(m.AddrComponent()) - require.Len(t, components, 2) - for _, c := range components { + cl.Add(certValidity - 2*clockSkewAllowance - time.Second) + require.Never(t, func() bool { + for i, c := range splitMultiaddr(m.AddrComponent()) { + if c.Value() != first[i].Value() { + return true + } + } + return false + }, 100*time.Millisecond, 10*time.Millisecond) + cl.Add(2 * time.Second) + require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond) + secondConf := m.GetConfig() + + second := splitMultiaddr(m.AddrComponent()) + require.Len(t, second, 2) + for _, c := range second { require.Equal(t, ma.P_CERTHASH, c.Protocol().Code) } - require.Equal(t, firstConf, m.GetConfig()) - cl.Add(certValidity / 2) - require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond) - newConf := m.GetConfig() - // check that the new config now matches the second component - hash := certificateHashFromTLSConfig(newConf) - require.Equal(t, hash[:], certHashFromComponent(t, components[1])) + // check that the 2nd certificate from the beginning was rolled over to be the 1st certificate + require.Equal(t, first[1].Value(), second[0].Value()) + require.NotEqual(t, first[0].Value(), second[1].Value()) + + cl.Add(certValidity - 2*clockSkewAllowance + time.Second) + require.Eventually(t, func() bool { return m.GetConfig() != secondConf }, 200*time.Millisecond, 10*time.Millisecond) + third := splitMultiaddr(m.AddrComponent()) + require.Len(t, third, 2) + for _, c := range third { + require.Equal(t, ma.P_CERTHASH, c.Protocol().Code) + } + // check that the 2nd certificate from the beginning was rolled over to be the 1st certificate + require.Equal(t, second[1].Value(), third[0].Value()) } diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index d85fcb828f..61cd245138 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -140,7 +140,16 @@ func TestHashVerification(t *testing.T) { t.Run("fails using only a wrong hash", func(t *testing.T) { // replace the certificate hash in the multiaddr with a fake hash - addr, _ := ma.SplitLast(ln.Multiaddr()) + addr := ln.Multiaddr() + // strip off all certhash components + for { + a, comp := ma.SplitLast(addr) + if comp.Protocol().Code != ma.P_CERTHASH { + break + } + addr = a + } + addr = addr.Encapsulate(foobarHash) _, err := tr2.Dial(context.Background(), addr, serverID) @@ -224,7 +233,7 @@ func TestListenerAddrs(t *testing.T) { ln2, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) require.NoError(t, err) hashes1 := extractCertHashes(ln1.Multiaddr()) - require.Len(t, hashes1, 1) + require.Len(t, hashes1, 2) hashes2 := extractCertHashes(ln2.Multiaddr()) require.Equal(t, hashes1, hashes2) }