Browse Source

move go-libp2p-webtransport here

pull/1737/head
Marten Seemann 2 years ago
parent
commit
48341e4060
  1. 168
      p2p/transport/webtransport/cert_manager.go
  2. 102
      p2p/transport/webtransport/cert_manager_test.go
  3. 60
      p2p/transport/webtransport/conn.go
  4. 108
      p2p/transport/webtransport/crypto.go
  5. 132
      p2p/transport/webtransport/crypto_test.go
  6. 249
      p2p/transport/webtransport/listener.go
  7. 109
      p2p/transport/webtransport/mock_connection_gater_test.go
  8. 80
      p2p/transport/webtransport/multiaddr.go
  9. 76
      p2p/transport/webtransport/multiaddr_test.go
  10. 34
      p2p/transport/webtransport/noise_early_data.go
  11. 11
      p2p/transport/webtransport/pb/Makefile
  12. 315
      p2p/transport/webtransport/pb/webtransport.pb.go
  13. 5
      p2p/transport/webtransport/pb/webtransport.proto
  14. 71
      p2p/transport/webtransport/stream.go
  15. 271
      p2p/transport/webtransport/transport.go
  16. 545
      p2p/transport/webtransport/transport_test.go

168
p2p/transport/webtransport/cert_manager.go

@ -0,0 +1,168 @@
package libp2pwebtransport
import (
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"fmt"
"sync"
"time"
"github.com/benbjohnson/clock"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
)
// Allow for a bit of clock skew.
// When we generate a certificate, the NotBefore time is set to clockSkewAllowance before the current time.
// Similarly, we stop using a certificate one clockSkewAllowance before its expiry time.
const clockSkewAllowance = time.Hour
type certConfig struct {
tlsConf *tls.Config
sha256 [32]byte // cached from the tlsConf
}
func (c *certConfig) Start() time.Time { return c.tlsConf.Certificates[0].Leaf.NotBefore }
func (c *certConfig) End() time.Time { return c.tlsConf.Certificates[0].Leaf.NotAfter }
func newCertConfig(start, end time.Time) (*certConfig, error) {
conf, err := getTLSConf(start, end)
if err != nil {
return nil, err
}
return &certConfig{
tlsConf: conf,
sha256: sha256.Sum256(conf.Certificates[0].Leaf.Raw),
}, nil
}
// Certificate renewal logic:
// 1. On startup, we generate one cert that is valid from now (-1h, to allow for clock skew), and another
// cert that is valid from the expiry date of the first certificate (again, with allowance for clock skew).
// 2. Once we reach 1h before expiry of the first certificate, we switch over to the second certificate.
// At the same time, we stop advertising the certhash of the first cert and generate the next cert.
type certManager struct {
clock clock.Clock
ctx context.Context
ctxCancel context.CancelFunc
refCount sync.WaitGroup
mx sync.RWMutex
lastConfig *certConfig // initially nil
currentConfig *certConfig
nextConfig *certConfig // nil until we have passed half the certValidity of the current config
addrComp ma.Multiaddr
}
func newCertManager(clock clock.Clock) (*certManager, error) {
m := &certManager{clock: clock}
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
if err := m.init(); err != nil {
return nil, err
}
m.background()
return m, nil
}
func (m *certManager) init() error {
start := m.clock.Now().Add(-clockSkewAllowance)
var err error
m.nextConfig, err = newCertConfig(start, start.Add(certValidity))
if err != nil {
return err
}
return m.rollConfig()
}
func (m *certManager) rollConfig() error {
// We stop using the current certificate clockSkewAllowance before its expiry time.
// At this point, the next certificate needs to be valid for one clockSkewAllowance.
nextStart := m.nextConfig.End().Add(-2 * clockSkewAllowance)
c, err := newCertConfig(nextStart, nextStart.Add(certValidity))
if err != nil {
return err
}
m.lastConfig = m.currentConfig
m.currentConfig = m.nextConfig
m.nextConfig = c
return m.cacheAddrComponent()
}
func (m *certManager) background() {
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(m.clock.Now())
log.Debugw("setting timer", "duration", d.String())
t := m.clock.Timer(d)
m.refCount.Add(1)
go func() {
defer m.refCount.Done()
defer t.Stop()
for {
select {
case <-m.ctx.Done():
return
case now := <-t.C:
m.mx.Lock()
if err := m.rollConfig(); err != nil {
log.Errorw("rolling config failed", "error", err)
}
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(now)
log.Debugw("rolling certificates", "next", d.String())
t.Reset(d)
m.mx.Unlock()
}
}
}()
}
func (m *certManager) GetConfig() *tls.Config {
m.mx.RLock()
defer m.mx.RUnlock()
return m.currentConfig.tlsConf
}
func (m *certManager) AddrComponent() ma.Multiaddr {
m.mx.RLock()
defer m.mx.RUnlock()
return m.addrComp
}
func (m *certManager) Verify(hashes []multihash.DecodedMultihash) error {
for _, h := range hashes {
if h.Code != multihash.SHA2_256 {
return fmt.Errorf("expected SHA256 hash, got %d", h.Code)
}
if !bytes.Equal(h.Digest, m.currentConfig.sha256[:]) &&
(m.nextConfig == nil || !bytes.Equal(h.Digest, m.nextConfig.sha256[:])) &&
(m.lastConfig == nil || !bytes.Equal(h.Digest, m.lastConfig.sha256[:])) {
return fmt.Errorf("found unexpected hash: %+x", h.Digest)
}
}
return nil
}
func (m *certManager) cacheAddrComponent() error {
addr, err := addrComponentForCert(m.currentConfig.sha256[:])
if err != nil {
return err
}
if m.nextConfig != nil {
comp, err := addrComponentForCert(m.nextConfig.sha256[:])
if err != nil {
return err
}
addr = addr.Encapsulate(comp)
}
m.addrComp = addr
return nil
}
func (m *certManager) Close() error {
m.ctxCancel()
m.refCount.Wait()
return nil
}

102
p2p/transport/webtransport/cert_manager_test.go

@ -0,0 +1,102 @@
package libp2pwebtransport
import (
"crypto/sha256"
"crypto/tls"
"testing"
"time"
"github.com/benbjohnson/clock"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"
"github.com/stretchr/testify/require"
)
func certificateHashFromTLSConfig(c *tls.Config) [32]byte {
return sha256.Sum256(c.Certificates[0].Certificate[0])
}
func splitMultiaddr(addr ma.Multiaddr) []ma.Component {
var components []ma.Component
ma.ForEach(addr, func(c ma.Component) bool {
components = append(components, c)
return true
})
return components
}
func certHashFromComponent(t *testing.T, comp ma.Component) []byte {
t.Helper()
_, data, err := multibase.Decode(comp.Value())
require.NoError(t, err)
mh, err := multihash.Decode(data)
require.NoError(t, err)
require.Equal(t, uint64(multihash.SHA2_256), mh.Code)
return mh.Digest
}
func TestInitialCert(t *testing.T) {
cl := clock.NewMock()
cl.Add(1234567 * time.Hour)
m, err := newCertManager(cl)
require.NoError(t, err)
defer m.Close()
conf := m.GetConfig()
require.Len(t, conf.Certificates, 1)
cert := conf.Certificates[0]
require.Equal(t, cl.Now().Add(-clockSkewAllowance).UTC(), cert.Leaf.NotBefore)
require.Equal(t, cert.Leaf.NotBefore.Add(certValidity), cert.Leaf.NotAfter)
addr := m.AddrComponent()
components := splitMultiaddr(addr)
require.Len(t, components, 2)
require.Equal(t, ma.P_CERTHASH, components[0].Protocol().Code)
hash := certificateHashFromTLSConfig(conf)
require.Equal(t, hash[:], certHashFromComponent(t, components[0]))
require.Equal(t, ma.P_CERTHASH, components[1].Protocol().Code)
}
func TestCertRenewal(t *testing.T) {
cl := clock.NewMock()
m, err := newCertManager(cl)
require.NoError(t, err)
defer m.Close()
firstConf := m.GetConfig()
first := splitMultiaddr(m.AddrComponent())
require.Len(t, first, 2)
require.NotEqual(t, first[0].Value(), first[1].Value(), "the hashes should differ")
// wait for a new certificate to be generated
cl.Add(certValidity - 2*clockSkewAllowance - time.Second)
require.Never(t, func() bool {
for i, c := range splitMultiaddr(m.AddrComponent()) {
if c.Value() != first[i].Value() {
return true
}
}
return false
}, 100*time.Millisecond, 10*time.Millisecond)
cl.Add(2 * time.Second)
require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond)
secondConf := m.GetConfig()
second := splitMultiaddr(m.AddrComponent())
require.Len(t, second, 2)
for _, c := range second {
require.Equal(t, ma.P_CERTHASH, c.Protocol().Code)
}
// check that the 2nd certificate from the beginning was rolled over to be the 1st certificate
require.Equal(t, first[1].Value(), second[0].Value())
require.NotEqual(t, first[0].Value(), second[1].Value())
cl.Add(certValidity - 2*clockSkewAllowance + time.Second)
require.Eventually(t, func() bool { return m.GetConfig() != secondConf }, 200*time.Millisecond, 10*time.Millisecond)
third := splitMultiaddr(m.AddrComponent())
require.Len(t, third, 2)
for _, c := range third {
require.Equal(t, ma.P_CERTHASH, c.Protocol().Code)
}
// check that the 2nd certificate from the beginning was rolled over to be the 1st certificate
require.Equal(t, second[1].Value(), third[0].Value())
}

60
p2p/transport/webtransport/conn.go

@ -0,0 +1,60 @@
package libp2pwebtransport
import (
"context"
"github.com/libp2p/go-libp2p/core/network"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/marten-seemann/webtransport-go"
ma "github.com/multiformats/go-multiaddr"
)
type connSecurityMultiaddrs struct {
network.ConnSecurity
network.ConnMultiaddrs
}
type connMultiaddrs struct {
local, remote ma.Multiaddr
}
var _ network.ConnMultiaddrs = &connMultiaddrs{}
func (c *connMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.local }
func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote }
type conn struct {
*connSecurityMultiaddrs
transport tpt.Transport
session *webtransport.Session
scope network.ConnScope
}
var _ tpt.CapableConn = &conn{}
func newConn(tr tpt.Transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnScope) *conn {
return &conn{
connSecurityMultiaddrs: sconn,
transport: tr,
session: sess,
scope: scope,
}
}
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
str, err := c.session.OpenStreamSync(ctx)
return &stream{str}, err
}
func (c *conn) AcceptStream() (network.MuxedStream, error) {
str, err := c.session.AcceptStream(context.Background())
return &stream{str}, err
}
func (c *conn) Close() error { return c.session.Close() }
func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil }
func (c *conn) Scope() network.ConnScope { return c.scope }
func (c *conn) Transport() tpt.Transport { return c.transport }

108
p2p/transport/webtransport/crypto.go

@ -0,0 +1,108 @@
package libp2pwebtransport
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/binary"
"errors"
"fmt"
"math/big"
"time"
"github.com/multiformats/go-multihash"
)
func getTLSConf(start, end time.Time) (*tls.Config, error) {
cert, priv, err := generateCert(start, end)
if err != nil {
return nil, err
}
return &tls.Config{
Certificates: []tls.Certificate{{
Certificate: [][]byte{cert.Raw},
PrivateKey: priv,
Leaf: cert,
}},
}, nil
}
func generateCert(start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, error) {
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
return nil, nil, err
}
serial := int64(binary.BigEndian.Uint64(b))
if serial < 0 {
serial = -serial
}
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(serial),
Subject: pkix.Name{},
NotBefore: start,
NotAfter: end,
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, err
}
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &caPrivateKey.PublicKey, caPrivateKey)
if err != nil {
return nil, nil, err
}
ca, err := x509.ParseCertificate(caBytes)
if err != nil {
return nil, nil, err
}
return ca, caPrivateKey, nil
}
func verifyRawCerts(rawCerts [][]byte, certHashes []multihash.DecodedMultihash) error {
if len(rawCerts) < 1 {
return errors.New("no cert")
}
leaf := rawCerts[len(rawCerts)-1]
// The W3C WebTransport specification currently only allows SHA-256 certificates for serverCertificateHashes.
hash := sha256.Sum256(leaf)
var verified bool
for _, h := range certHashes {
if h.Code == multihash.SHA2_256 && bytes.Equal(h.Digest, hash[:]) {
verified = true
break
}
}
if !verified {
digests := make([][]byte, 0, len(certHashes))
for _, h := range certHashes {
digests = append(digests, h.Digest)
}
return fmt.Errorf("cert hash not found: %#x (expected: %#x)", hash, digests)
}
cert, err := x509.ParseCertificate(leaf)
if err != nil {
return err
}
// TODO: is this the best (and complete?) way to identify RSA certificates?
switch cert.SignatureAlgorithm {
case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA, x509.MD2WithRSA, x509.MD5WithRSA:
return errors.New("cert uses RSA")
}
if l := cert.NotAfter.Sub(cert.NotBefore); l > 14*24*time.Hour {
return fmt.Errorf("cert must not be valid for longer than 14 days (NotBefore: %s, NotAfter: %s, Length: %s)", cert.NotBefore, cert.NotAfter, l)
}
now := time.Now()
if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
return fmt.Errorf("cert not valid (NotBefore: %s, NotAfter: %s)", cert.NotBefore, cert.NotAfter)
}
return nil
}

132
p2p/transport/webtransport/crypto_test.go

@ -0,0 +1,132 @@
package libp2pwebtransport
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
mrand "math/rand"
"testing"
"time"
"github.com/multiformats/go-multihash"
"github.com/stretchr/testify/require"
)
func sha256Multihash(t *testing.T, b []byte) multihash.DecodedMultihash {
t.Helper()
hash := sha256.Sum256(b)
h, err := multihash.Encode(hash[:], multihash.SHA2_256)
require.NoError(t, err)
dh, err := multihash.Decode(h)
require.NoError(t, err)
return *dh
}
func generateCertWithKey(t *testing.T, key crypto.PrivateKey, start, end time.Time) *x509.Certificate {
t.Helper()
serial := int64(mrand.Uint64())
if serial < 0 {
serial = -serial
}
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(serial),
Subject: pkix.Name{},
NotBefore: start,
NotAfter: end,
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, key.(interface{ Public() crypto.PublicKey }).Public(), key)
require.NoError(t, err)
ca, err := x509.ParseCertificate(caBytes)
require.NoError(t, err)
return ca
}
func TestCertificateVerification(t *testing.T) {
now := time.Now()
ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
rsaKey, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(t, err)
t.Run("accepting a valid cert", func(t *testing.T) {
validCert := generateCertWithKey(t, ecdsaKey, now, now.Add(14*24*time.Hour))
require.NoError(t, verifyRawCerts([][]byte{validCert.Raw}, []multihash.DecodedMultihash{sha256Multihash(t, validCert.Raw)}))
})
for _, tc := range [...]struct {
name string
cert *x509.Certificate
errStr string
}{
{
name: "validitity period too long",
cert: generateCertWithKey(t, ecdsaKey, now, now.Add(15*24*time.Hour)),
errStr: "cert must not be valid for longer than 14 days",
},
{
name: "uses RSA key",
cert: generateCertWithKey(t, rsaKey, now, now.Add(14*24*time.Hour)),
errStr: "RSA",
},
{
name: "expired certificate",
cert: generateCertWithKey(t, ecdsaKey, now.Add(-14*24*time.Hour), now),
errStr: "cert not valid",
},
{
name: "not yet valid",
cert: generateCertWithKey(t, ecdsaKey, now.Add(time.Hour), now.Add(time.Hour+14*24*time.Hour)),
errStr: "cert not valid",
},
} {
tc := tc
t.Run(fmt.Sprintf("rejecting invalid certificates: %s", tc.name), func(t *testing.T) {
err := verifyRawCerts([][]byte{tc.cert.Raw}, []multihash.DecodedMultihash{sha256Multihash(t, tc.cert.Raw)})
require.Error(t, err)
require.Contains(t, err.Error(), tc.errStr)
})
}
for _, tc := range [...]struct {
name string
certs [][]byte
hashes []multihash.DecodedMultihash
errStr string
}{
{
name: "no certificates",
hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))},
errStr: "no cert",
},
{
name: "certificate not parseable",
certs: [][]byte{[]byte("foobar")},
hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))},
errStr: "x509: malformed certificate",
},
{
name: "hash mismatch",
certs: [][]byte{generateCertWithKey(t, ecdsaKey, now, now.Add(15*24*time.Hour)).Raw},
hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))},
errStr: "cert hash not found",
},
} {
tc := tc
t.Run(fmt.Sprintf("rejecting invalid certificates: %s", tc.name), func(t *testing.T) {
err := verifyRawCerts(tc.certs, tc.hashes)
require.Error(t, err)
require.Contains(t, err.Error(), tc.errStr)
})
}
}

249
p2p/transport/webtransport/listener.go

@ -0,0 +1,249 @@
package libp2pwebtransport
import (
"context"
"crypto/tls"
"errors"
"fmt"
pb "github.com/marten-seemann/go-libp2p-webtransport/pb"
"github.com/multiformats/go-multihash"
"net"
"net/http"
"time"
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/network"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/security/noise"
"github.com/lucas-clemente/quic-go/http3"
"github.com/marten-seemann/webtransport-go"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
var errClosed = errors.New("closed")
const queueLen = 16
const handshakeTimeout = 10 * time.Second
type listener struct {
transport tpt.Transport
noise *noise.Transport
certManager *certManager
tlsConf *tls.Config
isStaticTLSConf bool
rcmgr network.ResourceManager
gater connmgr.ConnectionGater
server webtransport.Server
ctx context.Context
ctxCancel context.CancelFunc
serverClosed chan struct{} // is closed when server.Serve returns
addr net.Addr
multiaddr ma.Multiaddr
queue chan tpt.CapableConn
}
var _ tpt.Listener = &listener{}
func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, tlsConf *tls.Config, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Listener, error) {
network, addr, err := manet.DialArgs(laddr)
if err != nil {
return nil, err
}
udpAddr, err := net.ResolveUDPAddr(network, addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP(network, udpAddr)
if err != nil {
return nil, err
}
localMultiaddr, err := toWebtransportMultiaddr(udpConn.LocalAddr())
if err != nil {
return nil, err
}
isStaticTLSConf := tlsConf != nil
if tlsConf == nil {
tlsConf = &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
return certManager.GetConfig(), nil
}}
}
ln := &listener{
transport: transport,
noise: noise,
certManager: certManager,
tlsConf: tlsConf,
isStaticTLSConf: isStaticTLSConf,
rcmgr: rcmgr,
gater: gater,
queue: make(chan tpt.CapableConn, queueLen),
serverClosed: make(chan struct{}),
addr: udpConn.LocalAddr(),
multiaddr: localMultiaddr,
server: webtransport.Server{H3: http3.Server{TLSConfig: tlsConf}},
}
ln.ctx, ln.ctxCancel = context.WithCancel(context.Background())
mux := http.NewServeMux()
mux.HandleFunc(webtransportHTTPEndpoint, ln.httpHandler)
ln.server.H3.Handler = mux
go func() {
defer close(ln.serverClosed)
defer func() { udpConn.Close() }()
if err := ln.server.Serve(udpConn); err != nil {
// TODO: only output if the server hasn't been closed
log.Debugw("serving failed", "addr", udpConn.LocalAddr(), "error", err)
}
}()
return ln, nil
}
func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) {
remoteMultiaddr, err := stringToWebtransportMultiaddr(r.RemoteAddr)
if err != nil {
// This should never happen.
log.Errorw("converting remote address failed", "remote", r.RemoteAddr, "error", err)
w.WriteHeader(http.StatusBadRequest)
return
}
if l.gater != nil && !l.gater.InterceptAccept(&connMultiaddrs{local: l.multiaddr, remote: remoteMultiaddr}) {
w.WriteHeader(http.StatusForbidden)
return
}
connScope, err := l.rcmgr.OpenConnection(network.DirInbound, false, remoteMultiaddr)
if err != nil {
log.Debugw("resource manager blocked incoming connection", "addr", r.RemoteAddr, "error", err)
w.WriteHeader(http.StatusServiceUnavailable)
return
}
// TODO: check ?type=multistream URL param
sess, err := l.server.Upgrade(w, r)
if err != nil {
log.Debugw("upgrade failed", "error", err)
// TODO: think about the status code to use here
w.WriteHeader(500)
connScope.Done()
return
}
ctx, cancel := context.WithTimeout(l.ctx, handshakeTimeout)
sconn, err := l.handshake(ctx, sess)
if err != nil {
cancel()
log.Debugw("handshake failed", "error", err)
sess.Close()
connScope.Done()
return
}
cancel()
if l.gater != nil && !l.gater.InterceptSecured(network.DirInbound, sconn.RemotePeer(), sconn) {
// TODO: can we close with a specific error here?
sess.Close()
connScope.Done()
return
}
if err := connScope.SetPeer(sconn.RemotePeer()); err != nil {
log.Debugw("resource manager blocked incoming connection for peer", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err)
sess.Close()
connScope.Done()
return
}
select {
case l.queue <- newConn(l.transport, sess, sconn, connScope):
default:
log.Debugw("accept queue full, dropping incoming connection", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err)
sess.Close()
connScope.Done()
}
}
func (l *listener) Accept() (tpt.CapableConn, error) {
select {
case <-l.ctx.Done():
return nil, errClosed
case c := <-l.queue:
return c, nil
}
}
func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*connSecurityMultiaddrs, error) {
local, err := toWebtransportMultiaddr(sess.LocalAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting local addr: %w", err)
}
remote, err := toWebtransportMultiaddr(sess.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting remote addr: %w", err)
}
str, err := sess.AcceptStream(ctx)
if err != nil {
return nil, err
}
n, err := l.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(l.checkEarlyData)))
if err != nil {
return nil, fmt.Errorf("failed to initialize Noise session: %w", err)
}
c, err := n.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "")
if err != nil {
return nil, err
}
return &connSecurityMultiaddrs{
ConnSecurity: c,
ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote},
}, nil
}
func (l *listener) checkEarlyData(b []byte) error {
var msg pb.WebTransport
if err := msg.Unmarshal(b); err != nil {
fmt.Println(1)
return fmt.Errorf("failed to unmarshal early data protobuf: %w", err)
}
if l.isStaticTLSConf {
if len(msg.CertHashes) > 0 {
return errors.New("using static TLS config, didn't expect any certificate hashes")
}
return nil
}
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes))
for _, h := range msg.CertHashes {
dh, err := multihash.Decode(h)
if err != nil {
return fmt.Errorf("failed to decode hash: %w", err)
}
hashes = append(hashes, *dh)
}
return l.certManager.Verify(hashes)
}
func (l *listener) Addr() net.Addr {
return l.addr
}
func (l *listener) Multiaddr() ma.Multiaddr {
if l.certManager == nil {
return l.multiaddr
}
return l.multiaddr.Encapsulate(l.certManager.AddrComponent())
}
func (l *listener) Close() error {
l.ctxCancel()
err := l.server.Close()
<-l.serverClosed
return err
}

109
p2p/transport/webtransport/mock_connection_gater_test.go

@ -0,0 +1,109 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/libp2p/go-libp2p/core/connmgr (interfaces: ConnectionGater)
// Package libp2pwebtransport_test is a generated GoMock package.
package libp2pwebtransport_test
import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
control "github.com/libp2p/go-libp2p/core/control"
network "github.com/libp2p/go-libp2p/core/network"
peer "github.com/libp2p/go-libp2p/core/peer"
multiaddr "github.com/multiformats/go-multiaddr"
)
// MockConnectionGater is a mock of ConnectionGater interface.
type MockConnectionGater struct {
ctrl *gomock.Controller
recorder *MockConnectionGaterMockRecorder
}
// MockConnectionGaterMockRecorder is the mock recorder for MockConnectionGater.
type MockConnectionGaterMockRecorder struct {
mock *MockConnectionGater
}
// NewMockConnectionGater creates a new mock instance.
func NewMockConnectionGater(ctrl *gomock.Controller) *MockConnectionGater {
mock := &MockConnectionGater{ctrl: ctrl}
mock.recorder = &MockConnectionGaterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockConnectionGater) EXPECT() *MockConnectionGaterMockRecorder {
return m.recorder
}
// InterceptAccept mocks base method.
func (m *MockConnectionGater) InterceptAccept(arg0 network.ConnMultiaddrs) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InterceptAccept", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// InterceptAccept indicates an expected call of InterceptAccept.
func (mr *MockConnectionGaterMockRecorder) InterceptAccept(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAccept", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAccept), arg0)
}
// InterceptAddrDial mocks base method.
func (m *MockConnectionGater) InterceptAddrDial(arg0 peer.ID, arg1 multiaddr.Multiaddr) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InterceptAddrDial", arg0, arg1)
ret0, _ := ret[0].(bool)
return ret0
}
// InterceptAddrDial indicates an expected call of InterceptAddrDial.
func (mr *MockConnectionGaterMockRecorder) InterceptAddrDial(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAddrDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAddrDial), arg0, arg1)
}
// InterceptPeerDial mocks base method.
func (m *MockConnectionGater) InterceptPeerDial(arg0 peer.ID) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InterceptPeerDial", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// InterceptPeerDial indicates an expected call of InterceptPeerDial.
func (mr *MockConnectionGaterMockRecorder) InterceptPeerDial(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptPeerDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptPeerDial), arg0)
}
// InterceptSecured mocks base method.
func (m *MockConnectionGater) InterceptSecured(arg0 network.Direction, arg1 peer.ID, arg2 network.ConnMultiaddrs) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InterceptSecured", arg0, arg1, arg2)
ret0, _ := ret[0].(bool)
return ret0
}
// InterceptSecured indicates an expected call of InterceptSecured.
func (mr *MockConnectionGaterMockRecorder) InterceptSecured(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptSecured", reflect.TypeOf((*MockConnectionGater)(nil).InterceptSecured), arg0, arg1, arg2)
}
// InterceptUpgraded mocks base method.
func (m *MockConnectionGater) InterceptUpgraded(arg0 network.Conn) (bool, control.DisconnectReason) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InterceptUpgraded", arg0)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(control.DisconnectReason)
return ret0, ret1
}
// InterceptUpgraded indicates an expected call of InterceptUpgraded.
func (mr *MockConnectionGaterMockRecorder) InterceptUpgraded(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptUpgraded", reflect.TypeOf((*MockConnectionGater)(nil).InterceptUpgraded), arg0)
}

80
p2p/transport/webtransport/multiaddr.go

@ -0,0 +1,80 @@
package libp2pwebtransport
import (
"errors"
"fmt"
"net"
"strconv"
ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"
)
var webtransportMA = ma.StringCast("/quic/webtransport")
var webtransportMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC), mafmt.Base(ma.P_WEBTRANSPORT))
func toWebtransportMultiaddr(na net.Addr) (ma.Multiaddr, error) {
addr, err := manet.FromNetAddr(na)
if err != nil {
return nil, err
}
if _, err := addr.ValueForProtocol(ma.P_UDP); err != nil {
return nil, errors.New("not a UDP address")
}
return addr.Encapsulate(webtransportMA), nil
}
func stringToWebtransportMultiaddr(str string) (ma.Multiaddr, error) {
host, portStr, err := net.SplitHostPort(str)
if err != nil {
return nil, err
}
port, err := strconv.ParseInt(portStr, 10, 32)
if err != nil {
return nil, err
}
ip := net.ParseIP(host)
if ip == nil {
return nil, errors.New("failed to parse IP")
}
return toWebtransportMultiaddr(&net.UDPAddr{IP: ip, Port: int(port)})
}
func extractCertHashes(addr ma.Multiaddr) ([]multihash.DecodedMultihash, error) {
certHashesStr := make([]string, 0, 2)
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_CERTHASH {
certHashesStr = append(certHashesStr, c.Value())
}
return true
})
certHashes := make([]multihash.DecodedMultihash, 0, len(certHashesStr))
for _, s := range certHashesStr {
_, ch, err := multibase.Decode(s)
if err != nil {
return nil, fmt.Errorf("failed to multibase-decode certificate hash: %w", err)
}
dh, err := multihash.Decode(ch)
if err != nil {
return nil, fmt.Errorf("failed to multihash-decode certificate hash: %w", err)
}
certHashes = append(certHashes, *dh)
}
return certHashes, nil
}
func addrComponentForCert(hash []byte) (ma.Multiaddr, error) {
mh, err := multihash.Encode(hash, multihash.SHA2_256)
if err != nil {
return nil, err
}
certStr, err := multibase.Encode(multibase.Base58BTC, mh)
if err != nil {
return nil, err
}
return ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr)
}

76
p2p/transport/webtransport/multiaddr_test.go

@ -0,0 +1,76 @@
package libp2pwebtransport
import (
"fmt"
"net"
"testing"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"
"github.com/stretchr/testify/require"
)
func TestWebtransportMultiaddr(t *testing.T) {
t.Run("valid", func(t *testing.T) {
addr, err := toWebtransportMultiaddr(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337})
require.NoError(t, err)
require.Equal(t, "/ip4/127.0.0.1/udp/1337/quic/webtransport", addr.String())
})
t.Run("invalid", func(t *testing.T) {
_, err := toWebtransportMultiaddr(&net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337})
require.EqualError(t, err, "not a UDP address")
})
}
func TestWebtransportMultiaddrFromString(t *testing.T) {
t.Run("valid", func(t *testing.T) {
addr, err := stringToWebtransportMultiaddr("1.2.3.4:60042")
require.NoError(t, err)
require.Equal(t, "/ip4/1.2.3.4/udp/60042/quic/webtransport", addr.String())
})
t.Run("invalid", func(t *testing.T) {
for _, addr := range [...]string{
"1.2.3.4", // missing port
"1.2.3.4:123456", // invalid port
":1234", // missing IP
"foobar",
} {
_, err := stringToWebtransportMultiaddr(addr)
require.Error(t, err)
}
})
}
func encodeCertHash(t *testing.T, b []byte, mh uint64, mb multibase.Encoding) string {
t.Helper()
h, err := multihash.Encode(b, mh)
require.NoError(t, err)
str, err := multibase.Encode(mb, h)
require.NoError(t, err)
return str
}
func TestExtractCertHashes(t *testing.T) {
fooHash := encodeCertHash(t, []byte("foo"), multihash.SHA2_256, multibase.Base58BTC)
barHash := encodeCertHash(t, []byte("bar"), multihash.BLAKE2B_MAX, multibase.Base32)
// valid cases
for _, tc := range [...]struct {
addr string
hashes []string
}{
{addr: "/ip4/127.0.0.1/udp/1234/quic/webtransport"},
{addr: fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s", fooHash), hashes: []string{"foo"}},
{addr: fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s/certhash/%s", fooHash, barHash), hashes: []string{"foo", "bar"}},
} {
ch, err := extractCertHashes(ma.StringCast(tc.addr))
require.NoError(t, err)
require.Len(t, ch, len(tc.hashes))
for i, h := range tc.hashes {
require.Equal(t, h, string(ch[i].Digest))
}
}
}

34
p2p/transport/webtransport/noise_early_data.go

@ -0,0 +1,34 @@
package libp2pwebtransport
import (
"context"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/security/noise"
"net"
)
type earlyDataHandler struct {
earlyData []byte
receive func([]byte) error
}
var _ noise.EarlyDataHandler = &earlyDataHandler{}
func newEarlyDataSender(earlyData []byte) noise.EarlyDataHandler {
return &earlyDataHandler{earlyData: earlyData}
}
func newEarlyDataReceiver(receive func([]byte) error) noise.EarlyDataHandler {
return &earlyDataHandler{receive: receive}
}
func (e *earlyDataHandler) Send(context.Context, net.Conn, peer.ID) []byte {
return e.earlyData
}
func (e *earlyDataHandler) Received(_ context.Context, _ net.Conn, b []byte) error {
if e.receive == nil {
return nil
}
return e.receive(b)
}

11
p2p/transport/webtransport/pb/Makefile

@ -0,0 +1,11 @@
PB = $(wildcard *.proto)
GO = $(PB:.proto=.pb.go)
all: $(GO)
%.pb.go: %.proto
protoc --proto_path=$(PWD)/../..:. --gogofaster_out=. $<
clean:
rm -f *.pb.go
rm -f *.go

315
p2p/transport/webtransport/pb/webtransport.pb.go

@ -0,0 +1,315 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: webtransport.proto
package webtransport
import (
fmt "fmt"
proto "github.com/gogo/protobuf/proto"
io "io"
math "math"
math_bits "math/bits"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type WebTransport struct {
CertHashes [][]byte `protobuf:"bytes,1,rep,name=cert_hashes,json=certHashes" json:"cert_hashes,omitempty"`
}
func (m *WebTransport) Reset() { *m = WebTransport{} }
func (m *WebTransport) String() string { return proto.CompactTextString(m) }
func (*WebTransport) ProtoMessage() {}
func (*WebTransport) Descriptor() ([]byte, []int) {
return fileDescriptor_db878920ab41a4f3, []int{0}
}
func (m *WebTransport) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *WebTransport) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_WebTransport.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *WebTransport) XXX_Merge(src proto.Message) {
xxx_messageInfo_WebTransport.Merge(m, src)
}
func (m *WebTransport) XXX_Size() int {
return m.Size()
}
func (m *WebTransport) XXX_DiscardUnknown() {
xxx_messageInfo_WebTransport.DiscardUnknown(m)
}
var xxx_messageInfo_WebTransport proto.InternalMessageInfo
func (m *WebTransport) GetCertHashes() [][]byte {
if m != nil {
return m.CertHashes
}
return nil
}
func init() {
proto.RegisterType((*WebTransport)(nil), "WebTransport")
}
func init() { proto.RegisterFile("webtransport.proto", fileDescriptor_db878920ab41a4f3) }
var fileDescriptor_db878920ab41a4f3 = []byte{
// 109 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x2a, 0x4f, 0x4d, 0x2a,
0x29, 0x4a, 0xcc, 0x2b, 0x2e, 0xc8, 0x2f, 0x2a, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x57, 0xd2,
0xe7, 0xe2, 0x09, 0x4f, 0x4d, 0x0a, 0x81, 0x89, 0x0a, 0xc9, 0x73, 0x71, 0x27, 0xa7, 0x16, 0x95,
0xc4, 0x67, 0x24, 0x16, 0x67, 0xa4, 0x16, 0x4b, 0x30, 0x2a, 0x30, 0x6b, 0xf0, 0x04, 0x71, 0x81,
0x84, 0x3c, 0xc0, 0x22, 0x4e, 0x12, 0x27, 0x1e, 0xc9, 0x31, 0x5e, 0x78, 0x24, 0xc7, 0xf8, 0xe0,
0x91, 0x1c, 0xe3, 0x84, 0xc7, 0x72, 0x0c, 0x17, 0x1e, 0xcb, 0x31, 0xdc, 0x78, 0x2c, 0xc7, 0x00,
0x08, 0x00, 0x00, 0xff, 0xff, 0x50, 0x77, 0xe5, 0x52, 0x5f, 0x00, 0x00, 0x00,
}
func (m *WebTransport) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *WebTransport) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *WebTransport) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if len(m.CertHashes) > 0 {
for iNdEx := len(m.CertHashes) - 1; iNdEx >= 0; iNdEx-- {
i -= len(m.CertHashes[iNdEx])
copy(dAtA[i:], m.CertHashes[iNdEx])
i = encodeVarintWebtransport(dAtA, i, uint64(len(m.CertHashes[iNdEx])))
i--
dAtA[i] = 0xa
}
}
return len(dAtA) - i, nil
}
func encodeVarintWebtransport(dAtA []byte, offset int, v uint64) int {
offset -= sovWebtransport(v)
base := offset
for v >= 1<<7 {
dAtA[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
dAtA[offset] = uint8(v)
return base
}
func (m *WebTransport) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
if len(m.CertHashes) > 0 {
for _, b := range m.CertHashes {
l = len(b)
n += 1 + l + sovWebtransport(uint64(l))
}
}
return n
}
func sovWebtransport(x uint64) (n int) {
return (math_bits.Len64(x|1) + 6) / 7
}
func sozWebtransport(x uint64) (n int) {
return sovWebtransport(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *WebTransport) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowWebtransport
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: WebTransport: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: WebTransport: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field CertHashes", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowWebtransport
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthWebtransport
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthWebtransport
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.CertHashes = append(m.CertHashes, make([]byte, postIndex-iNdEx))
copy(m.CertHashes[len(m.CertHashes)-1], dAtA[iNdEx:postIndex])
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipWebtransport(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return ErrInvalidLengthWebtransport
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipWebtransport(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0
depth := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowWebtransport
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowWebtransport
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if dAtA[iNdEx-1] < 0x80 {
break
}
}
case 1:
iNdEx += 8
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowWebtransport
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
if length < 0 {
return 0, ErrInvalidLengthWebtransport
}
iNdEx += length
case 3:
depth++
case 4:
if depth == 0 {
return 0, ErrUnexpectedEndOfGroupWebtransport
}
depth--
case 5:
iNdEx += 4
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
if iNdEx < 0 {
return 0, ErrInvalidLengthWebtransport
}
if depth == 0 {
return iNdEx, nil
}
}
return 0, io.ErrUnexpectedEOF
}
var (
ErrInvalidLengthWebtransport = fmt.Errorf("proto: negative length found during unmarshaling")
ErrIntOverflowWebtransport = fmt.Errorf("proto: integer overflow")
ErrUnexpectedEndOfGroupWebtransport = fmt.Errorf("proto: unexpected end of group")
)

5
p2p/transport/webtransport/pb/webtransport.proto

@ -0,0 +1,5 @@
syntax = "proto2";
message WebTransport {
repeated bytes cert_hashes = 1;
}

71
p2p/transport/webtransport/stream.go

@ -0,0 +1,71 @@
package libp2pwebtransport
import (
"errors"
"net"
"github.com/marten-seemann/webtransport-go"
"github.com/libp2p/go-libp2p/core/network"
)
const (
reset webtransport.ErrorCode = 0
)
type webtransportStream struct {
webtransport.Stream
wsess *webtransport.Session
}
var _ net.Conn = &webtransportStream{}
func (s *webtransportStream) LocalAddr() net.Addr {
return s.wsess.LocalAddr()
}
func (s *webtransportStream) RemoteAddr() net.Addr {
return s.wsess.RemoteAddr()
}
type stream struct {
webtransport.Stream
}
var _ network.MuxedStream = &stream{}
func (s *stream) Read(b []byte) (n int, err error) {
n, err = s.Stream.Read(b)
if err != nil && errors.Is(err, &webtransport.StreamError{}) {
err = network.ErrReset
}
return n, err
}
func (s *stream) Write(b []byte) (n int, err error) {
n, err = s.Stream.Write(b)
if err != nil && errors.Is(err, &webtransport.StreamError{}) {
err = network.ErrReset
}
return n, err
}
func (s *stream) Reset() error {
s.Stream.CancelRead(reset)
s.Stream.CancelWrite(reset)
return nil
}
func (s *stream) Close() error {
s.Stream.CancelRead(reset)
return s.Stream.Close()
}
func (s *stream) CloseRead() error {
s.Stream.CancelRead(reset)
return nil
}
func (s *stream) CloseWrite() error {
return s.Stream.Close()
}

271
p2p/transport/webtransport/transport.go

@ -0,0 +1,271 @@
package libp2pwebtransport
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"sync"
"time"
pb "github.com/marten-seemann/go-libp2p-webtransport/pb"
"github.com/libp2p/go-libp2p/core/connmgr"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/security/noise"
"github.com/benbjohnson/clock"
logging "github.com/ipfs/go-log/v2"
"github.com/lucas-clemente/quic-go/http3"
"github.com/marten-seemann/webtransport-go"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multihash"
)
var log = logging.Logger("webtransport")
const webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport"
const certValidity = 14 * 24 * time.Hour
type Option func(*transport) error
func WithClock(cl clock.Clock) Option {
return func(t *transport) error {
t.clock = cl
return nil
}
}
// WithTLSConfig sets a tls.Config used for listening.
// When used, the certificate from that config will be used, and no /certhash will be added to the listener's multiaddr.
// This is most useful when running a listener that has a valid (CA-signed) certificate.
func WithTLSConfig(c *tls.Config) Option {
return func(t *transport) error {
t.staticTLSConf = c
return nil
}
}
// WithTLSClientConfig sets a custom tls.Config used for dialing.
// This option is most useful for setting a custom tls.Config.RootCAs certificate pool.
// When dialing a multiaddr that contains a /certhash component, this library will set InsecureSkipVerify and
// overwrite the VerifyPeerCertificate callback.
func WithTLSClientConfig(c *tls.Config) Option {
return func(t *transport) error {
t.tlsClientConf = c
return nil
}
}
type transport struct {
privKey ic.PrivKey
pid peer.ID
clock clock.Clock
rcmgr network.ResourceManager
gater connmgr.ConnectionGater
listenOnce sync.Once
listenOnceErr error
certManager *certManager
staticTLSConf *tls.Config
tlsClientConf *tls.Config
noise *noise.Transport
}
var _ tpt.Transport = &transport{}
var _ io.Closer = &transport{}
func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) {
id, err := peer.IDFromPrivateKey(key)
if err != nil {
return nil, err
}
t := &transport{
pid: id,
privKey: key,
rcmgr: rcmgr,
gater: gater,
clock: clock.New(),
}
for _, opt := range opts {
if err := opt(t); err != nil {
return nil, err
}
}
n, err := noise.New(key)
if err != nil {
return nil, err
}
t.noise = n
return t, nil
}
func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) {
_, addr, err := manet.DialArgs(raddr)
if err != nil {
return nil, err
}
certHashes, err := extractCertHashes(raddr)
if err != nil {
return nil, err
}
scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr)
if err != nil {
log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err)
return nil, err
}
if err := scope.SetPeer(p); err != nil {
log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err)
scope.Done()
return nil, err
}
sess, err := t.dial(ctx, addr, certHashes)
if err != nil {
scope.Done()
return nil, err
}
sconn, err := t.upgrade(ctx, sess, p, certHashes)
if err != nil {
sess.Close()
scope.Done()
return nil, err
}
if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, sconn) {
// TODO: can we close with a specific error here?
sess.Close()
scope.Done()
return nil, fmt.Errorf("secured connection gated")
}
return newConn(t, sess, sconn, scope), nil
}
func (t *transport) dial(ctx context.Context, addr string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) {
url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint)
var tlsConf *tls.Config
if t.tlsClientConf != nil {
tlsConf = t.tlsClientConf.Clone()
} else {
tlsConf = &tls.Config{}
}
if len(certHashes) > 0 {
// This is not insecure. We verify the certificate ourselves.
// See https://www.w3.org/TR/webtransport/#certificate-hashes.
tlsConf.InsecureSkipVerify = true
tlsConf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
return verifyRawCerts(rawCerts, certHashes)
}
}
dialer := webtransport.Dialer{
RoundTripper: &http3.RoundTripper{TLSClientConfig: tlsConf},
}
rsp, sess, err := dialer.Dial(ctx, url, nil)
if err != nil {
return nil, err
}
if rsp.StatusCode < 200 || rsp.StatusCode > 299 {
return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode)
}
return sess, err
}
func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) {
local, err := toWebtransportMultiaddr(sess.LocalAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting local addr: %w", err)
}
remote, err := toWebtransportMultiaddr(sess.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting remote addr: %w", err)
}
str, err := sess.OpenStreamSync(ctx)
if err != nil {
return nil, err
}
// Now run a Noise handshake (using early data) and send all the certificate hashes that we would have accepted.
// The server will verify that it advertised all of these certificate hashes.
msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(certHashes))}
for _, certHash := range certHashes {
h, err := multihash.Encode(certHash.Digest, certHash.Code)
if err != nil {
return nil, fmt.Errorf("failed to encode certificate hash: %w", err)
}
msg.CertHashes = append(msg.CertHashes, h)
}
msgBytes, err := msg.Marshal()
if err != nil {
return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err)
}
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes)))
if err != nil {
return nil, fmt.Errorf("failed to create Noise transport: %w", err)
}
c, err := n.SecureOutbound(ctx, &webtransportStream{Stream: str, wsess: sess}, p)
if err != nil {
return nil, err
}
return &connSecurityMultiaddrs{
ConnSecurity: c,
ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote},
}, nil
}
func (t *transport) CanDial(addr ma.Multiaddr) bool {
var numHashes int
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_CERTHASH {
numHashes++
}
return true
})
// Remove the /certhash components from the multiaddr.
// If the multiaddr doesn't contain any certhashes, the node might have a CA-signed certificate.
for i := 0; i < numHashes; i++ {
addr, _ = ma.SplitLast(addr)
}
return webtransportMatcher.Matches(addr)
}
func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) {
if !webtransportMatcher.Matches(laddr) {
return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr)
}
if t.staticTLSConf == nil {
t.listenOnce.Do(func() {
t.certManager, t.listenOnceErr = newCertManager(t.clock)
})
if t.listenOnceErr != nil {
return nil, t.listenOnceErr
}
}
return newListener(laddr, t, t.noise, t.certManager, t.staticTLSConf, t.gater, t.rcmgr)
}
func (t *transport) Protocols() []int {
return []int{ma.P_WEBTRANSPORT}
}
func (t *transport) Proxy() bool {
return false
}
func (t *transport) Close() error {
t.listenOnce.Do(func() {})
if t.certManager != nil {
return t.certManager.Close()
}
return nil
}

545
p2p/transport/webtransport/transport_test.go

@ -0,0 +1,545 @@
package libp2pwebtransport_test
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"fmt"
"io"
"math/big"
"net"
"strings"
"testing"
"time"
libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
mocknetwork "github.com/libp2p/go-libp2p/core/network/mocks"
"github.com/libp2p/go-libp2p/core/peer"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/golang/mock/gomock"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"
"github.com/stretchr/testify/require"
)
func newIdentity(t *testing.T) (peer.ID, ic.PrivKey) {
key, _, err := ic.GenerateEd25519Key(rand.Reader)
require.NoError(t, err)
id, err := peer.IDFromPrivateKey(key)
require.NoError(t, err)
return id, key
}
func randomMultihash(t *testing.T) string {
t.Helper()
b := make([]byte, 16)
rand.Read(b)
h, err := multihash.Encode(b, multihash.KECCAK_224)
require.NoError(t, err)
s, err := multibase.Encode(multibase.Base32hex, h)
require.NoError(t, err)
return s
}
func extractCertHashes(addr ma.Multiaddr) []string {
var certHashesStr []string
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_CERTHASH {
certHashesStr = append(certHashesStr, c.Value())
}
return true
})
return certHashesStr
}
func stripCertHashes(addr ma.Multiaddr) ma.Multiaddr {
for {
_, err := addr.ValueForProtocol(ma.P_CERTHASH)
if err != nil {
return addr
}
addr, _ = ma.SplitLast(addr)
}
}
// create a /certhash multiaddr component using the SHA256 of foobar
func getCerthashComponent(t *testing.T, b []byte) ma.Multiaddr {
t.Helper()
h := sha256.Sum256(b)
mh, err := multihash.Encode(h[:], multihash.SHA2_256)
require.NoError(t, err)
certStr, err := multibase.Encode(multibase.Base58BTC, mh)
require.NoError(t, err)
ha, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr)
require.NoError(t, err)
return ha
}
func TestTransport(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
defer ln.Close()
addrChan := make(chan ma.Multiaddr)
go func() {
_, clientKey := newIdentity(t)
tr2, err := libp2pwebtransport.New(clientKey, nil, network.NullResourceManager)
require.NoError(t, err)
defer tr2.(io.Closer).Close()
conn, err := tr2.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
str, err := conn.OpenStream(context.Background())
require.NoError(t, err)
_, err = str.Write([]byte("foobar"))
require.NoError(t, err)
require.NoError(t, str.Close())
// check RemoteMultiaddr
_, addr, err := manet.DialArgs(ln.Multiaddr())
require.NoError(t, err)
_, port, err := net.SplitHostPort(addr)
require.NoError(t, err)
require.Equal(t, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%s/quic/webtransport", port)), conn.RemoteMultiaddr())
addrChan <- conn.RemoteMultiaddr()
}()
conn, err := ln.Accept()
require.NoError(t, err)
require.False(t, conn.IsClosed())
str, err := conn.AcceptStream()
require.NoError(t, err)
data, err := io.ReadAll(str)
require.NoError(t, err)
require.Equal(t, "foobar", string(data))
require.Equal(t, <-addrChan, conn.LocalMultiaddr())
require.NoError(t, conn.Close())
require.True(t, conn.IsClosed())
}
func TestHashVerification(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
done := make(chan struct{})
go func() {
defer close(done)
_, err := ln.Accept()
require.Error(t, err)
}()
_, clientKey := newIdentity(t)
tr2, err := libp2pwebtransport.New(clientKey, nil, network.NullResourceManager)
require.NoError(t, err)
defer tr2.(io.Closer).Close()
foobarHash := getCerthashComponent(t, []byte("foobar"))
t.Run("fails using only a wrong hash", func(t *testing.T) {
// replace the certificate hash in the multiaddr with a fake hash
addr := stripCertHashes(ln.Multiaddr()).Encapsulate(foobarHash)
_, err := tr2.Dial(context.Background(), addr, serverID)
require.Error(t, err)
require.Contains(t, err.Error(), "CRYPTO_ERROR (0x12a): cert hash not found")
})
t.Run("fails when adding a wrong hash", func(t *testing.T) {
_, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID)
require.Error(t, err)
})
require.NoError(t, ln.Close())
<-done
}
func TestCanDial(t *testing.T) {
valid := []ma.Multiaddr{
ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/" + randomMultihash(t)),
ma.StringCast("/ip6/b16b:8255:efc6:9cd5:1a54:ee86:2d7a:c2e6/udp/1234/quic/webtransport/certhash/" + randomMultihash(t)),
ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s/certhash/%s/certhash/%s", randomMultihash(t), randomMultihash(t), randomMultihash(t))),
ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport"), // no certificate hash
}
invalid := []ma.Multiaddr{
ma.StringCast("/ip4/127.0.0.1/udp/1234"), // missing webtransport
ma.StringCast("/ip4/127.0.0.1/udp/1234/webtransport"), // missing quic
ma.StringCast("/ip4/127.0.0.1/tcp/1234/webtransport"), // WebTransport over TCP? Is this a joke?
}
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
for _, addr := range valid {
require.Truef(t, tr.CanDial(addr), "expected to be able to dial %s", addr)
}
for _, addr := range invalid {
require.Falsef(t, tr.CanDial(addr), "expected to not be able to dial %s", addr)
}
}
func TestListenAddrValidity(t *testing.T) {
valid := []ma.Multiaddr{
ma.StringCast("/ip6/::/udp/0/quic/webtransport/"),
ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/"),
}
invalid := []ma.Multiaddr{
ma.StringCast("/ip4/127.0.0.1/udp/1234"), // missing webtransport
ma.StringCast("/ip4/127.0.0.1/udp/1234/webtransport"), // missing quic
ma.StringCast("/ip4/127.0.0.1/tcp/1234/webtransport"), // WebTransport over TCP? Is this a joke?
ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/" + randomMultihash(t)),
}
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
for _, addr := range valid {
ln, err := tr.Listen(addr)
require.NoErrorf(t, err, "expected to be able to listen on %s", addr)
ln.Close()
}
for _, addr := range invalid {
_, err := tr.Listen(addr)
require.Errorf(t, err, "expected to not be able to listen on %s", addr)
}
}
func TestListenerAddrs(t *testing.T) {
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln1, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
ln2, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
hashes1 := extractCertHashes(ln1.Multiaddr())
require.Len(t, hashes1, 2)
hashes2 := extractCertHashes(ln2.Multiaddr())
require.Equal(t, hashes1, hashes2)
}
func TestResourceManagerDialing(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
addr := ma.StringCast("/ip4/9.8.7.6/udp/1234/quic/webtransport")
p := peer.ID("foobar")
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, nil, rcmgr)
require.NoError(t, err)
defer tr.(io.Closer).Close()
scope := mocknetwork.NewMockConnManagementScope(ctrl)
rcmgr.EXPECT().OpenConnection(network.DirOutbound, false, addr).Return(scope, nil)
scope.EXPECT().SetPeer(p).Return(errors.New("denied"))
scope.EXPECT().Done()
_, err = tr.Dial(context.Background(), addr, p)
require.EqualError(t, err, "denied")
}
func TestResourceManagerListening(t *testing.T) {
clientID, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
require.NoError(t, err)
defer cl.(io.Closer).Close()
t.Run("blocking the connection", func(t *testing.T) {
serverID, key := newIdentity(t)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
tr, err := libp2pwebtransport.New(key, nil, rcmgr)
require.NoError(t, err)
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
defer ln.Close()
rcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).DoAndReturn(func(_ network.Direction, _ bool, addr ma.Multiaddr) (network.ConnManagementScope, error) {
_, err := addr.ValueForProtocol(ma.P_WEBTRANSPORT)
require.NoError(t, err, "expected a WebTransport multiaddr")
_, addrStr, err := manet.DialArgs(addr)
require.NoError(t, err)
host, _, err := net.SplitHostPort(addrStr)
require.NoError(t, err)
require.Equal(t, "127.0.0.1", host)
return nil, errors.New("denied")
})
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.EqualError(t, err, "received status 503")
})
t.Run("blocking the peer", func(t *testing.T) {
serverID, key := newIdentity(t)
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
tr, err := libp2pwebtransport.New(key, nil, rcmgr)
require.NoError(t, err)
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
defer ln.Close()
serverDone := make(chan struct{})
scope := mocknetwork.NewMockConnManagementScope(ctrl)
rcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).Return(scope, nil)
scope.EXPECT().SetPeer(clientID).Return(errors.New("denied"))
scope.EXPECT().Done().Do(func() { close(serverDone) })
// The handshake will complete, but the server will immediately close the connection.
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
defer conn.Close()
clientDone := make(chan struct{})
go func() {
defer close(clientDone)
_, err = conn.AcceptStream()
require.Error(t, err)
}()
select {
case <-clientDone:
case <-time.After(5 * time.Second):
t.Fatal("timeout")
}
select {
case <-serverDone:
case <-time.After(5 * time.Second):
t.Fatal("timeout")
}
})
}
// TODO: unify somehow. We do the same in libp2pquic.
//go:generate sh -c "mockgen -package libp2pwebtransport_test -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater && goimports -w mock_connection_gater_test.go"
func TestConnectionGaterDialing(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
defer ln.Close()
connGater.EXPECT().InterceptSecured(network.DirOutbound, serverID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr())
})
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, connGater, network.NullResourceManager)
require.NoError(t, err)
defer cl.(io.Closer).Close()
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.EqualError(t, err, "secured connection gated")
}
func TestConnectionGaterInterceptAccept(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, connGater, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
defer ln.Close()
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.LocalMultiaddr())
require.NotEqual(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr())
})
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
require.NoError(t, err)
defer cl.(io.Closer).Close()
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.EqualError(t, err, "received status 403")
}
func TestConnectionGaterInterceptSecured(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, connGater, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
defer ln.Close()
clientID, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
require.NoError(t, err)
defer cl.(io.Closer).Close()
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true)
connGater.EXPECT().InterceptSecured(network.DirInbound, clientID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.LocalMultiaddr())
require.NotEqual(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr())
})
// The handshake will complete, but the server will immediately close the connection.
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
defer conn.Close()
done := make(chan struct{})
go func() {
defer close(done)
_, err = conn.AcceptStream()
require.Error(t, err)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatal("timeout")
}
}
func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config {
t.Helper()
certTempl := &x509.Certificate{
SerialNumber: big.NewInt(1234),
Subject: pkix.Name{Organization: []string{"webtransport"}},
NotBefore: start,
NotAfter: end,
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
IPAddresses: []net.IP{ip},
}
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &priv.PublicKey, priv)
require.NoError(t, err)
cert, err := x509.ParseCertificate(caBytes)
require.NoError(t, err)
return &tls.Config{
Certificates: []tls.Certificate{{
Certificate: [][]byte{cert.Raw},
PrivateKey: priv,
Leaf: cert,
}},
}
}
func TestStaticTLSConf(t *testing.T) {
tlsConf := getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(365*24*time.Hour))
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager, libp2pwebtransport.WithTLSConfig(tlsConf))
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
defer ln.Close()
require.Empty(t, extractCertHashes(ln.Multiaddr()), "listener address shouldn't contain any certhash")
t.Run("fails when the certificate is invalid", func(t *testing.T) {
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
require.NoError(t, err)
defer cl.(io.Closer).Close()
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.Error(t, err)
if !strings.Contains(err.Error(), "certificate is not trusted") &&
!strings.Contains(err.Error(), "certificate signed by unknown authority") {
t.Fatalf("expected a certificate error, got %+v", err)
}
})
t.Run("fails when dialing with a wrong certhash", func(t *testing.T) {
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
require.NoError(t, err)
defer cl.(io.Closer).Close()
addr := ln.Multiaddr().Encapsulate(getCerthashComponent(t, []byte("foo")))
_, err = cl.Dial(context.Background(), addr, serverID)
require.Error(t, err)
require.Contains(t, err.Error(), "cert hash not found")
})
t.Run("accepts a valid TLS certificate", func(t *testing.T) {
_, key := newIdentity(t)
store := x509.NewCertPool()
store.AddCert(tlsConf.Certificates[0].Leaf)
tlsConf := &tls.Config{RootCAs: store}
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager, libp2pwebtransport.WithTLSClientConfig(tlsConf))
require.NoError(t, err)
defer cl.(io.Closer).Close()
require.True(t, cl.CanDial(ln.Multiaddr()))
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.NoError(t, err)
defer conn.Close()
})
}
func TestAcceptQueueFilledUp(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
defer ln.Close()
newConn := func() (tpt.CapableConn, error) {
t.Helper()
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
require.NoError(t, err)
defer cl.(io.Closer).Close()
return cl.Dial(context.Background(), ln.Multiaddr(), serverID)
}
for i := 0; i < 16; i++ {
conn, err := newConn()
require.NoError(t, err)
defer conn.Close()
}
conn, err := newConn()
if err == nil {
_, err = conn.AcceptStream()
}
require.Error(t, err)
}
Loading…
Cancel
Save