diff --git a/p2p/transport/webtransport/cert_manager.go b/p2p/transport/webtransport/cert_manager.go index c005d448b..a86e66876 100644 --- a/p2p/transport/webtransport/cert_manager.go +++ b/p2p/transport/webtransport/cert_manager.go @@ -1,7 +1,6 @@ package libp2pwebtransport import ( - "bytes" "context" "crypto/sha256" "crypto/tls" @@ -9,6 +8,8 @@ import ( "sync" "time" + pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb" + "github.com/benbjohnson/clock" ma "github.com/multiformats/go-multiaddr" "github.com/multiformats/go-multihash" @@ -54,6 +55,8 @@ type certManager struct { currentConfig *certConfig nextConfig *certConfig // nil until we have passed half the certValidity of the current config addrComp ma.Multiaddr + + protobuf []byte } func newCertManager(clock clock.Clock) (*certManager, error) { @@ -88,6 +91,9 @@ func (m *certManager) rollConfig() error { m.lastConfig = m.currentConfig m.currentConfig = m.nextConfig m.nextConfig = c + if err := m.cacheProtobuf(); err != nil { + return err + } return m.cacheAddrComponent() } @@ -131,17 +137,33 @@ func (m *certManager) AddrComponent() ma.Multiaddr { return m.addrComp } -func (m *certManager) Verify(hashes []multihash.DecodedMultihash) error { - for _, h := range hashes { - if h.Code != multihash.SHA2_256 { - return fmt.Errorf("expected SHA256 hash, got %d", h.Code) - } - if !bytes.Equal(h.Digest, m.currentConfig.sha256[:]) && - (m.nextConfig == nil || !bytes.Equal(h.Digest, m.nextConfig.sha256[:])) && - (m.lastConfig == nil || !bytes.Equal(h.Digest, m.lastConfig.sha256[:])) { - return fmt.Errorf("found unexpected hash: %+x", h.Digest) +func (m *certManager) Protobuf() []byte { + return m.protobuf +} + +func (m *certManager) cacheProtobuf() error { + hashes := make([][32]byte, 0, 3) + if m.lastConfig != nil { + hashes = append(hashes, m.lastConfig.sha256) + } + hashes = append(hashes, m.currentConfig.sha256) + if m.nextConfig != nil { + hashes = append(hashes, m.nextConfig.sha256) + } + + msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(hashes))} + for _, certHash := range hashes { + h, err := multihash.Encode(certHash[:], multihash.SHA2_256) + if err != nil { + return fmt.Errorf("failed to encode certificate hash: %w", err) } + msg.CertHashes = append(msg.CertHashes, h) + } + msgBytes, err := msg.Marshal() + if err != nil { + return fmt.Errorf("failed to marshal WebTransport protobuf: %w", err) } + m.protobuf = msgBytes return nil } diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index ee95d8502..f768c81d8 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -9,17 +9,17 @@ import ( "net/http" "time" + pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb" + "github.com/libp2p/go-libp2p/core/connmgr" "github.com/libp2p/go-libp2p/core/network" tpt "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/security/noise" - pb "github.com/libp2p/go-libp2p/p2p/transport/webtransport/pb" "github.com/lucas-clemente/quic-go/http3" "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" - "github.com/multiformats/go-multihash" ) var errClosed = errors.New("closed") @@ -197,7 +197,19 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (* if err != nil { return nil, err } - n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataReceiver(l.checkEarlyData))) + var earlyData []byte + if l.isStaticTLSConf { + var msg pb.WebTransport + var err error + earlyData, err = msg.Marshal() + if err != nil { + return nil, err + } + } else { + earlyData = l.certManager.Protobuf() + } + + n, err := l.noise.WithSessionOptions(noise.EarlyData(nil, newEarlyDataSender(earlyData))) if err != nil { return nil, fmt.Errorf("failed to initialize Noise session: %w", err) } @@ -212,31 +224,6 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (* }, nil } -func (l *listener) checkEarlyData(b []byte) error { - var msg pb.WebTransport - if err := msg.Unmarshal(b); err != nil { - fmt.Println(1) - return fmt.Errorf("failed to unmarshal early data protobuf: %w", err) - } - - if l.isStaticTLSConf { - if len(msg.CertHashes) > 0 { - return errors.New("using static TLS config, didn't expect any certificate hashes") - } - return nil - } - - hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes)) - for _, h := range msg.CertHashes { - dh, err := multihash.Decode(h) - if err != nil { - return fmt.Errorf("failed to decode hash: %w", err) - } - hashes = append(hashes, *dh) - } - return l.certManager.Verify(hashes) -} - func (l *listener) Addr() net.Addr { return l.addr } diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 176642f15..60d33a5b7 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -1,9 +1,11 @@ package libp2pwebtransport import ( + "bytes" "context" "crypto/tls" "crypto/x509" + "errors" "fmt" "io" "sync" @@ -196,19 +198,27 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p // Now run a Noise handshake (using early data) and send all the certificate hashes that we would have accepted. // The server will verify that it advertised all of these certificate hashes. - msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(certHashes))} - for _, certHash := range certHashes { - h, err := multihash.Encode(certHash.Digest, certHash.Code) + var verified bool + n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(func(b []byte) error { + decodedCertHashes, err := decodeCertHashesFromProtobuf(b) if err != nil { - return nil, fmt.Errorf("failed to encode certificate hash: %w", err) + return err } - msg.CertHashes = append(msg.CertHashes, h) - } - msgBytes, err := msg.Marshal() - if err != nil { - return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err) - } - n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes), nil)) + for _, sent := range certHashes { + var found bool + for _, rcvd := range decodedCertHashes { + if sent.Code == rcvd.Code && bytes.Equal(sent.Digest, rcvd.Digest) { + found = true + break + } + } + if !found { + return fmt.Errorf("missing cert hash: %v", sent) + } + } + verified = true + return nil + }), nil)) if err != nil { return nil, fmt.Errorf("failed to create Noise transport: %w", err) } @@ -216,12 +226,34 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p if err != nil { return nil, err } + // The Noise handshake _should_ guarantee that our verification callback is called. + // Double-check just in case. + if !verified { + return nil, errors.New("didn't verify") + } return &connSecurityMultiaddrs{ ConnSecurity: c, ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote}, }, nil } +func decodeCertHashesFromProtobuf(b []byte) ([]multihash.DecodedMultihash, error) { + var msg pb.WebTransport + if err := msg.Unmarshal(b); err != nil { + return nil, fmt.Errorf("failed to unmarshal early data protobuf: %w", err) + } + + hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes)) + for _, h := range msg.CertHashes { + dh, err := multihash.Decode(h) + if err != nil { + return nil, fmt.Errorf("failed to decode hash: %w", err) + } + hashes = append(hashes, *dh) + } + return hashes, nil +} + func (t *transport) CanDial(addr ma.Multiaddr) bool { var numHashes int ma.ForEach(addr, func(c ma.Component) bool { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 3001438d7..92680fafb 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -162,11 +162,8 @@ func TestHashVerification(t *testing.T) { }) t.Run("fails when adding a wrong hash", func(t *testing.T) { - conn, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID) - if err != nil { - _, err = conn.AcceptStream() - require.Error(t, err) - } + _, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID) + require.Error(t, err) }) require.NoError(t, ln.Close())