Browse Source

optimize expiry periods of certificates (#21)

pull/1737/head
Marten Seemann 2 years ago
committed by GitHub
parent
commit
2823159a99
  1. 106
      p2p/transport/webtransport/cert_manager.go
  2. 53
      p2p/transport/webtransport/cert_manager_test.go
  3. 13
      p2p/transport/webtransport/transport_test.go

106
p2p/transport/webtransport/cert_manager.go

@ -15,11 +15,19 @@ import (
"github.com/multiformats/go-multihash" "github.com/multiformats/go-multihash"
) )
// Allow for a bit of clock skew.
// When we generate a certificate, the NotBefore time is set to clockSkewAllowance before the current time.
// Similarly, we stop using a certificate one clockSkewAllowance before its expiry time.
const clockSkewAllowance = time.Hour
type certConfig struct { type certConfig struct {
tlsConf *tls.Config tlsConf *tls.Config
sha256 [32]byte // cached from the tlsConf sha256 [32]byte // cached from the tlsConf
} }
func (c *certConfig) Start() time.Time { return c.tlsConf.Certificates[0].Leaf.NotBefore }
func (c *certConfig) End() time.Time { return c.tlsConf.Certificates[0].Leaf.NotAfter }
func newCertConfig(start, end time.Time) (*certConfig, error) { func newCertConfig(start, end time.Time) (*certConfig, error) {
conf, err := getTLSConf(start, end) conf, err := getTLSConf(start, end)
if err != nil { if err != nil {
@ -32,22 +40,17 @@ func newCertConfig(start, end time.Time) (*certConfig, error) {
} }
// Certificate renewal logic: // Certificate renewal logic:
// 0. To simplify the math, assume the certificate is valid for 10 days (in real life: 14 days). // 1. On startup, we generate one cert that is valid from now (-1h, to allow for clock skew), and another
// 1. On startup, we generate the first certificate (1). // cert that is valid from the expiry date of the first certificate (again, with allowance for clock skew).
// 2. After 4 days, we generate a second certificate (2). // 2. Once we reach 1h before expiry of the first certificate, we switch over to the second certificate.
// We don't use that certificate yet, but we advertise the hashes of (1) and (2). // At the same time, we stop advertising the certhash of the first cert and generate the next cert.
// That allows clients to connect to us using addresses that are 4 days old.
// 3. After another 4 days, we now actually start using (2).
// We also generate a third certificate (3), and start advertising the hashes of (2) and (3).
// We continue to remember the hash of (1) for validation during the Noise handshake for another 4 days,
// as the client might be connecting with a cached address.
type certManager struct { type certManager struct {
clock clock.Clock clock clock.Clock
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
refCount sync.WaitGroup refCount sync.WaitGroup
mx sync.Mutex mx sync.RWMutex
lastConfig *certConfig // initially nil lastConfig *certConfig // initially nil
currentConfig *certConfig currentConfig *certConfig
nextConfig *certConfig // nil until we have passed half the certValidity of the current config nextConfig *certConfig // nil until we have passed half the certValidity of the current config
@ -61,64 +64,71 @@ func newCertManager(clock clock.Clock) (*certManager, error) {
return nil, err return nil, err
} }
t := m.clock.Ticker(certValidity * 4 / 9) // make sure we're a bit faster than 1/2 m.background()
m.refCount.Add(1)
go func() {
defer m.refCount.Done()
defer t.Stop()
if err := m.background(t); err != nil {
log.Fatal(err)
}
}()
return m, nil return m, nil
} }
func (m *certManager) init() error { func (m *certManager) init() error {
start := m.clock.Now() start := m.clock.Now().Add(-clockSkewAllowance)
end := start.Add(certValidity) var err error
cc, err := newCertConfig(start, end) m.nextConfig, err = newCertConfig(start, start.Add(certValidity))
if err != nil { if err != nil {
return err return err
} }
m.currentConfig = cc return m.rollConfig()
}
func (m *certManager) rollConfig() error {
// We stop using a certificate one clockSkewAllowance before its expiry time.
// At that point, its successor must already have been valid for one clockSkewAllowance,
// so the successor starts two clockSkewAllowances before the expiry of the certificate it replaces.
nextStart := m.nextConfig.End().Add(-2 * clockSkewAllowance)
c, err := newCertConfig(nextStart, nextStart.Add(certValidity))
if err != nil {
return err
}
m.lastConfig = m.currentConfig
m.currentConfig = m.nextConfig
m.nextConfig = c
return m.cacheAddrComponent() return m.cacheAddrComponent()
} }
func (m *certManager) background(t *clock.Ticker) error { func (m *certManager) background() {
for { d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(m.clock.Now())
select { log.Debugw("setting timer", "duration", d.String())
case <-m.ctx.Done(): t := m.clock.Timer(d)
return nil m.refCount.Add(1)
case start := <-t.C:
end := start.Add(certValidity) go func() {
cc, err := newCertConfig(start, end) defer m.refCount.Done()
if err != nil { defer t.Stop()
return err
} for {
m.mx.Lock() select {
if m.nextConfig != nil { case <-m.ctx.Done():
m.lastConfig = m.currentConfig return
m.currentConfig = m.nextConfig case now := <-t.C:
} m.mx.Lock()
m.nextConfig = cc if err := m.rollConfig(); err != nil {
if err := m.cacheAddrComponent(); err != nil { log.Errorw("rolling config failed", "error", err)
}
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(now)
log.Debugw("rolling certificates", "next", d.String())
t.Reset(d)
m.mx.Unlock() m.mx.Unlock()
return err
} }
m.mx.Unlock()
} }
} }()
} }
func (m *certManager) GetConfig() *tls.Config { func (m *certManager) GetConfig() *tls.Config {
m.mx.Lock() m.mx.RLock()
defer m.mx.Unlock() defer m.mx.RUnlock()
return m.currentConfig.tlsConf return m.currentConfig.tlsConf
} }
func (m *certManager) AddrComponent() ma.Multiaddr { func (m *certManager) AddrComponent() ma.Multiaddr {
m.mx.Lock() m.mx.RLock()
defer m.mx.Unlock() defer m.mx.RUnlock()
return m.addrComp return m.addrComp
} }

53
p2p/transport/webtransport/cert_manager_test.go

@ -46,14 +46,15 @@ func TestInitialCert(t *testing.T) {
conf := m.GetConfig() conf := m.GetConfig()
require.Len(t, conf.Certificates, 1) require.Len(t, conf.Certificates, 1)
cert := conf.Certificates[0] cert := conf.Certificates[0]
require.Equal(t, cl.Now().UTC(), cert.Leaf.NotBefore) require.Equal(t, cl.Now().Add(-clockSkewAllowance).UTC(), cert.Leaf.NotBefore)
require.Equal(t, cl.Now().Add(certValidity).UTC(), cert.Leaf.NotAfter) require.Equal(t, cert.Leaf.NotBefore.Add(certValidity), cert.Leaf.NotAfter)
addr := m.AddrComponent() addr := m.AddrComponent()
components := splitMultiaddr(addr) components := splitMultiaddr(addr)
require.Len(t, components, 1) require.Len(t, components, 2)
require.Equal(t, ma.P_CERTHASH, components[0].Protocol().Code) require.Equal(t, ma.P_CERTHASH, components[0].Protocol().Code)
hash := certificateHashFromTLSConfig(conf) hash := certificateHashFromTLSConfig(conf)
require.Equal(t, hash[:], certHashFromComponent(t, components[0])) require.Equal(t, hash[:], certHashFromComponent(t, components[0]))
require.Equal(t, ma.P_CERTHASH, components[1].Protocol().Code)
} }
func TestCertRenewal(t *testing.T) { func TestCertRenewal(t *testing.T) {
@ -63,21 +64,39 @@ func TestCertRenewal(t *testing.T) {
defer m.Close() defer m.Close()
firstConf := m.GetConfig() firstConf := m.GetConfig()
require.Len(t, splitMultiaddr(m.AddrComponent()), 1) first := splitMultiaddr(m.AddrComponent())
require.Len(t, first, 2)
require.NotEqual(t, first[0].Value(), first[1].Value(), "the hashes should differ")
// wait for a new certificate to be generated // wait for a new certificate to be generated
cl.Add(certValidity / 2) cl.Add(certValidity - 2*clockSkewAllowance - time.Second)
require.Eventually(t, func() bool { return len(splitMultiaddr(m.AddrComponent())) > 1 }, 200*time.Millisecond, 10*time.Millisecond) require.Never(t, func() bool {
// the actual config used should still be the same, we're just advertising the hash of the next config for i, c := range splitMultiaddr(m.AddrComponent()) {
components := splitMultiaddr(m.AddrComponent()) if c.Value() != first[i].Value() {
require.Len(t, components, 2) return true
for _, c := range components { }
}
return false
}, 100*time.Millisecond, 10*time.Millisecond)
cl.Add(2 * time.Second)
require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond)
secondConf := m.GetConfig()
second := splitMultiaddr(m.AddrComponent())
require.Len(t, second, 2)
for _, c := range second {
require.Equal(t, ma.P_CERTHASH, c.Protocol().Code) require.Equal(t, ma.P_CERTHASH, c.Protocol().Code)
} }
require.Equal(t, firstConf, m.GetConfig()) // check that the 2nd certificate from the beginning was rolled over to be the 1st certificate
cl.Add(certValidity / 2) require.Equal(t, first[1].Value(), second[0].Value())
require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond) require.NotEqual(t, first[0].Value(), second[1].Value())
newConf := m.GetConfig()
// check that the new config now matches the second component cl.Add(certValidity - 2*clockSkewAllowance + time.Second)
hash := certificateHashFromTLSConfig(newConf) require.Eventually(t, func() bool { return m.GetConfig() != secondConf }, 200*time.Millisecond, 10*time.Millisecond)
require.Equal(t, hash[:], certHashFromComponent(t, components[1])) third := splitMultiaddr(m.AddrComponent())
require.Len(t, third, 2)
for _, c := range third {
require.Equal(t, ma.P_CERTHASH, c.Protocol().Code)
}
// check that the 2nd certificate from the previous rollover was rolled over to be the 1st certificate
require.Equal(t, second[1].Value(), third[0].Value())
} }

13
p2p/transport/webtransport/transport_test.go

@ -140,7 +140,16 @@ func TestHashVerification(t *testing.T) {
t.Run("fails using only a wrong hash", func(t *testing.T) { t.Run("fails using only a wrong hash", func(t *testing.T) {
// replace the certificate hash in the multiaddr with a fake hash // replace the certificate hash in the multiaddr with a fake hash
addr, _ := ma.SplitLast(ln.Multiaddr()) addr := ln.Multiaddr()
// strip off all certhash components
for {
a, comp := ma.SplitLast(addr)
if comp.Protocol().Code != ma.P_CERTHASH {
break
}
addr = a
}
addr = addr.Encapsulate(foobarHash) addr = addr.Encapsulate(foobarHash)
_, err := tr2.Dial(context.Background(), addr, serverID) _, err := tr2.Dial(context.Background(), addr, serverID)
@ -224,7 +233,7 @@ func TestListenerAddrs(t *testing.T) {
ln2, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) ln2, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err) require.NoError(t, err)
hashes1 := extractCertHashes(ln1.Multiaddr()) hashes1 := extractCertHashes(ln1.Multiaddr())
require.Len(t, hashes1, 1) require.Len(t, hashes1, 2)
hashes2 := extractCertHashes(ln2.Multiaddr()) hashes2 := extractCertHashes(ln2.Multiaddr())
require.Equal(t, hashes1, hashes2) require.Equal(t, hashes1, hashes2)
} }

Loading…
Cancel
Save