mirror of https://github.com/libp2p/go-libp2p.git
Marten Seemann
2 years ago
16 changed files with 2336 additions and 0 deletions
@ -0,0 +1,168 @@ |
|||
package libp2pwebtransport |
|||
|
|||
import ( |
|||
"bytes" |
|||
"context" |
|||
"crypto/sha256" |
|||
"crypto/tls" |
|||
"fmt" |
|||
"sync" |
|||
"time" |
|||
|
|||
"github.com/benbjohnson/clock" |
|||
ma "github.com/multiformats/go-multiaddr" |
|||
"github.com/multiformats/go-multihash" |
|||
) |
|||
|
|||
// Allow for a bit of clock skew.
|
|||
// When we generate a certificate, the NotBefore time is set to clockSkewAllowance before the current time.
|
|||
// Similarly, we stop using a certificate one clockSkewAllowance before its expiry time.
|
|||
const clockSkewAllowance = time.Hour |
|||
|
|||
type certConfig struct { |
|||
tlsConf *tls.Config |
|||
sha256 [32]byte // cached from the tlsConf
|
|||
} |
|||
|
|||
func (c *certConfig) Start() time.Time { return c.tlsConf.Certificates[0].Leaf.NotBefore } |
|||
func (c *certConfig) End() time.Time { return c.tlsConf.Certificates[0].Leaf.NotAfter } |
|||
|
|||
func newCertConfig(start, end time.Time) (*certConfig, error) { |
|||
conf, err := getTLSConf(start, end) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return &certConfig{ |
|||
tlsConf: conf, |
|||
sha256: sha256.Sum256(conf.Certificates[0].Leaf.Raw), |
|||
}, nil |
|||
} |
|||
|
|||
// Certificate renewal logic:
|
|||
// 1. On startup, we generate one cert that is valid from now (-1h, to allow for clock skew), and another
|
|||
// cert that is valid from the expiry date of the first certificate (again, with allowance for clock skew).
|
|||
// 2. Once we reach 1h before expiry of the first certificate, we switch over to the second certificate.
|
|||
// At the same time, we stop advertising the certhash of the first cert and generate the next cert.
|
|||
type certManager struct { |
|||
clock clock.Clock |
|||
ctx context.Context |
|||
ctxCancel context.CancelFunc |
|||
refCount sync.WaitGroup |
|||
|
|||
mx sync.RWMutex |
|||
lastConfig *certConfig // initially nil
|
|||
currentConfig *certConfig |
|||
nextConfig *certConfig // nil until we have passed half the certValidity of the current config
|
|||
addrComp ma.Multiaddr |
|||
} |
|||
|
|||
func newCertManager(clock clock.Clock) (*certManager, error) { |
|||
m := &certManager{clock: clock} |
|||
m.ctx, m.ctxCancel = context.WithCancel(context.Background()) |
|||
if err := m.init(); err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
m.background() |
|||
return m, nil |
|||
} |
|||
|
|||
func (m *certManager) init() error { |
|||
start := m.clock.Now().Add(-clockSkewAllowance) |
|||
var err error |
|||
m.nextConfig, err = newCertConfig(start, start.Add(certValidity)) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
return m.rollConfig() |
|||
} |
|||
|
|||
func (m *certManager) rollConfig() error { |
|||
// We stop using the current certificate clockSkewAllowance before its expiry time.
|
|||
// At this point, the next certificate needs to be valid for one clockSkewAllowance.
|
|||
nextStart := m.nextConfig.End().Add(-2 * clockSkewAllowance) |
|||
c, err := newCertConfig(nextStart, nextStart.Add(certValidity)) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
m.lastConfig = m.currentConfig |
|||
m.currentConfig = m.nextConfig |
|||
m.nextConfig = c |
|||
return m.cacheAddrComponent() |
|||
} |
|||
|
|||
func (m *certManager) background() { |
|||
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(m.clock.Now()) |
|||
log.Debugw("setting timer", "duration", d.String()) |
|||
t := m.clock.Timer(d) |
|||
m.refCount.Add(1) |
|||
|
|||
go func() { |
|||
defer m.refCount.Done() |
|||
defer t.Stop() |
|||
|
|||
for { |
|||
select { |
|||
case <-m.ctx.Done(): |
|||
return |
|||
case now := <-t.C: |
|||
m.mx.Lock() |
|||
if err := m.rollConfig(); err != nil { |
|||
log.Errorw("rolling config failed", "error", err) |
|||
} |
|||
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(now) |
|||
log.Debugw("rolling certificates", "next", d.String()) |
|||
t.Reset(d) |
|||
m.mx.Unlock() |
|||
} |
|||
} |
|||
}() |
|||
} |
|||
|
|||
func (m *certManager) GetConfig() *tls.Config { |
|||
m.mx.RLock() |
|||
defer m.mx.RUnlock() |
|||
return m.currentConfig.tlsConf |
|||
} |
|||
|
|||
func (m *certManager) AddrComponent() ma.Multiaddr { |
|||
m.mx.RLock() |
|||
defer m.mx.RUnlock() |
|||
return m.addrComp |
|||
} |
|||
|
|||
func (m *certManager) Verify(hashes []multihash.DecodedMultihash) error { |
|||
for _, h := range hashes { |
|||
if h.Code != multihash.SHA2_256 { |
|||
return fmt.Errorf("expected SHA256 hash, got %d", h.Code) |
|||
} |
|||
if !bytes.Equal(h.Digest, m.currentConfig.sha256[:]) && |
|||
(m.nextConfig == nil || !bytes.Equal(h.Digest, m.nextConfig.sha256[:])) && |
|||
(m.lastConfig == nil || !bytes.Equal(h.Digest, m.lastConfig.sha256[:])) { |
|||
return fmt.Errorf("found unexpected hash: %+x", h.Digest) |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func (m *certManager) cacheAddrComponent() error { |
|||
addr, err := addrComponentForCert(m.currentConfig.sha256[:]) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
if m.nextConfig != nil { |
|||
comp, err := addrComponentForCert(m.nextConfig.sha256[:]) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
addr = addr.Encapsulate(comp) |
|||
} |
|||
m.addrComp = addr |
|||
return nil |
|||
} |
|||
|
|||
func (m *certManager) Close() error { |
|||
m.ctxCancel() |
|||
m.refCount.Wait() |
|||
return nil |
|||
} |
@ -0,0 +1,102 @@ |
|||
package libp2pwebtransport |
|||
|
|||
import ( |
|||
"crypto/sha256" |
|||
"crypto/tls" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/benbjohnson/clock" |
|||
ma "github.com/multiformats/go-multiaddr" |
|||
"github.com/multiformats/go-multibase" |
|||
"github.com/multiformats/go-multihash" |
|||
"github.com/stretchr/testify/require" |
|||
) |
|||
|
|||
func certificateHashFromTLSConfig(c *tls.Config) [32]byte { |
|||
return sha256.Sum256(c.Certificates[0].Certificate[0]) |
|||
} |
|||
|
|||
func splitMultiaddr(addr ma.Multiaddr) []ma.Component { |
|||
var components []ma.Component |
|||
ma.ForEach(addr, func(c ma.Component) bool { |
|||
components = append(components, c) |
|||
return true |
|||
}) |
|||
return components |
|||
} |
|||
|
|||
func certHashFromComponent(t *testing.T, comp ma.Component) []byte { |
|||
t.Helper() |
|||
_, data, err := multibase.Decode(comp.Value()) |
|||
require.NoError(t, err) |
|||
mh, err := multihash.Decode(data) |
|||
require.NoError(t, err) |
|||
require.Equal(t, uint64(multihash.SHA2_256), mh.Code) |
|||
return mh.Digest |
|||
} |
|||
|
|||
func TestInitialCert(t *testing.T) { |
|||
cl := clock.NewMock() |
|||
cl.Add(1234567 * time.Hour) |
|||
m, err := newCertManager(cl) |
|||
require.NoError(t, err) |
|||
defer m.Close() |
|||
|
|||
conf := m.GetConfig() |
|||
require.Len(t, conf.Certificates, 1) |
|||
cert := conf.Certificates[0] |
|||
require.Equal(t, cl.Now().Add(-clockSkewAllowance).UTC(), cert.Leaf.NotBefore) |
|||
require.Equal(t, cert.Leaf.NotBefore.Add(certValidity), cert.Leaf.NotAfter) |
|||
addr := m.AddrComponent() |
|||
components := splitMultiaddr(addr) |
|||
require.Len(t, components, 2) |
|||
require.Equal(t, ma.P_CERTHASH, components[0].Protocol().Code) |
|||
hash := certificateHashFromTLSConfig(conf) |
|||
require.Equal(t, hash[:], certHashFromComponent(t, components[0])) |
|||
require.Equal(t, ma.P_CERTHASH, components[1].Protocol().Code) |
|||
} |
|||
|
|||
func TestCertRenewal(t *testing.T) { |
|||
cl := clock.NewMock() |
|||
m, err := newCertManager(cl) |
|||
require.NoError(t, err) |
|||
defer m.Close() |
|||
|
|||
firstConf := m.GetConfig() |
|||
first := splitMultiaddr(m.AddrComponent()) |
|||
require.Len(t, first, 2) |
|||
require.NotEqual(t, first[0].Value(), first[1].Value(), "the hashes should differ") |
|||
// wait for a new certificate to be generated
|
|||
cl.Add(certValidity - 2*clockSkewAllowance - time.Second) |
|||
require.Never(t, func() bool { |
|||
for i, c := range splitMultiaddr(m.AddrComponent()) { |
|||
if c.Value() != first[i].Value() { |
|||
return true |
|||
} |
|||
} |
|||
return false |
|||
}, 100*time.Millisecond, 10*time.Millisecond) |
|||
cl.Add(2 * time.Second) |
|||
require.Eventually(t, func() bool { return m.GetConfig() != firstConf }, 200*time.Millisecond, 10*time.Millisecond) |
|||
secondConf := m.GetConfig() |
|||
|
|||
second := splitMultiaddr(m.AddrComponent()) |
|||
require.Len(t, second, 2) |
|||
for _, c := range second { |
|||
require.Equal(t, ma.P_CERTHASH, c.Protocol().Code) |
|||
} |
|||
// check that the 2nd certificate from the beginning was rolled over to be the 1st certificate
|
|||
require.Equal(t, first[1].Value(), second[0].Value()) |
|||
require.NotEqual(t, first[0].Value(), second[1].Value()) |
|||
|
|||
cl.Add(certValidity - 2*clockSkewAllowance + time.Second) |
|||
require.Eventually(t, func() bool { return m.GetConfig() != secondConf }, 200*time.Millisecond, 10*time.Millisecond) |
|||
third := splitMultiaddr(m.AddrComponent()) |
|||
require.Len(t, third, 2) |
|||
for _, c := range third { |
|||
require.Equal(t, ma.P_CERTHASH, c.Protocol().Code) |
|||
} |
|||
// check that the 2nd certificate from the beginning was rolled over to be the 1st certificate
|
|||
require.Equal(t, second[1].Value(), third[0].Value()) |
|||
} |
@ -0,0 +1,60 @@ |
|||
package libp2pwebtransport |
|||
|
|||
import ( |
|||
"context" |
|||
|
|||
"github.com/libp2p/go-libp2p/core/network" |
|||
tpt "github.com/libp2p/go-libp2p/core/transport" |
|||
|
|||
"github.com/marten-seemann/webtransport-go" |
|||
ma "github.com/multiformats/go-multiaddr" |
|||
) |
|||
|
|||
type connSecurityMultiaddrs struct { |
|||
network.ConnSecurity |
|||
network.ConnMultiaddrs |
|||
} |
|||
|
|||
type connMultiaddrs struct { |
|||
local, remote ma.Multiaddr |
|||
} |
|||
|
|||
var _ network.ConnMultiaddrs = &connMultiaddrs{} |
|||
|
|||
func (c *connMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.local } |
|||
func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote } |
|||
|
|||
type conn struct { |
|||
*connSecurityMultiaddrs |
|||
|
|||
transport tpt.Transport |
|||
session *webtransport.Session |
|||
|
|||
scope network.ConnScope |
|||
} |
|||
|
|||
var _ tpt.CapableConn = &conn{} |
|||
|
|||
func newConn(tr tpt.Transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnScope) *conn { |
|||
return &conn{ |
|||
connSecurityMultiaddrs: sconn, |
|||
transport: tr, |
|||
session: sess, |
|||
scope: scope, |
|||
} |
|||
} |
|||
|
|||
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { |
|||
str, err := c.session.OpenStreamSync(ctx) |
|||
return &stream{str}, err |
|||
} |
|||
|
|||
func (c *conn) AcceptStream() (network.MuxedStream, error) { |
|||
str, err := c.session.AcceptStream(context.Background()) |
|||
return &stream{str}, err |
|||
} |
|||
|
|||
func (c *conn) Close() error { return c.session.Close() } |
|||
func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil } |
|||
func (c *conn) Scope() network.ConnScope { return c.scope } |
|||
func (c *conn) Transport() tpt.Transport { return c.transport } |
@ -0,0 +1,108 @@ |
|||
package libp2pwebtransport |
|||
|
|||
import ( |
|||
"bytes" |
|||
"crypto/ecdsa" |
|||
"crypto/elliptic" |
|||
"crypto/rand" |
|||
"crypto/sha256" |
|||
"crypto/tls" |
|||
"crypto/x509" |
|||
"crypto/x509/pkix" |
|||
"encoding/binary" |
|||
"errors" |
|||
"fmt" |
|||
"math/big" |
|||
"time" |
|||
|
|||
"github.com/multiformats/go-multihash" |
|||
) |
|||
|
|||
func getTLSConf(start, end time.Time) (*tls.Config, error) { |
|||
cert, priv, err := generateCert(start, end) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return &tls.Config{ |
|||
Certificates: []tls.Certificate{{ |
|||
Certificate: [][]byte{cert.Raw}, |
|||
PrivateKey: priv, |
|||
Leaf: cert, |
|||
}}, |
|||
}, nil |
|||
} |
|||
|
|||
func generateCert(start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, error) { |
|||
b := make([]byte, 8) |
|||
if _, err := rand.Read(b); err != nil { |
|||
return nil, nil, err |
|||
} |
|||
serial := int64(binary.BigEndian.Uint64(b)) |
|||
if serial < 0 { |
|||
serial = -serial |
|||
} |
|||
certTempl := &x509.Certificate{ |
|||
SerialNumber: big.NewInt(serial), |
|||
Subject: pkix.Name{}, |
|||
NotBefore: start, |
|||
NotAfter: end, |
|||
IsCA: true, |
|||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, |
|||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, |
|||
BasicConstraintsValid: true, |
|||
} |
|||
caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) |
|||
if err != nil { |
|||
return nil, nil, err |
|||
} |
|||
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &caPrivateKey.PublicKey, caPrivateKey) |
|||
if err != nil { |
|||
return nil, nil, err |
|||
} |
|||
ca, err := x509.ParseCertificate(caBytes) |
|||
if err != nil { |
|||
return nil, nil, err |
|||
} |
|||
return ca, caPrivateKey, nil |
|||
} |
|||
|
|||
func verifyRawCerts(rawCerts [][]byte, certHashes []multihash.DecodedMultihash) error { |
|||
if len(rawCerts) < 1 { |
|||
return errors.New("no cert") |
|||
} |
|||
leaf := rawCerts[len(rawCerts)-1] |
|||
// The W3C WebTransport specification currently only allows SHA-256 certificates for serverCertificateHashes.
|
|||
hash := sha256.Sum256(leaf) |
|||
var verified bool |
|||
for _, h := range certHashes { |
|||
if h.Code == multihash.SHA2_256 && bytes.Equal(h.Digest, hash[:]) { |
|||
verified = true |
|||
break |
|||
} |
|||
} |
|||
if !verified { |
|||
digests := make([][]byte, 0, len(certHashes)) |
|||
for _, h := range certHashes { |
|||
digests = append(digests, h.Digest) |
|||
} |
|||
return fmt.Errorf("cert hash not found: %#x (expected: %#x)", hash, digests) |
|||
} |
|||
|
|||
cert, err := x509.ParseCertificate(leaf) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
// TODO: is this the best (and complete?) way to identify RSA certificates?
|
|||
switch cert.SignatureAlgorithm { |
|||
case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA, x509.MD2WithRSA, x509.MD5WithRSA: |
|||
return errors.New("cert uses RSA") |
|||
} |
|||
if l := cert.NotAfter.Sub(cert.NotBefore); l > 14*24*time.Hour { |
|||
return fmt.Errorf("cert must not be valid for longer than 14 days (NotBefore: %s, NotAfter: %s, Length: %s)", cert.NotBefore, cert.NotAfter, l) |
|||
} |
|||
now := time.Now() |
|||
if now.Before(cert.NotBefore) || now.After(cert.NotAfter) { |
|||
return fmt.Errorf("cert not valid (NotBefore: %s, NotAfter: %s)", cert.NotBefore, cert.NotAfter) |
|||
} |
|||
return nil |
|||
} |
@ -0,0 +1,132 @@ |
|||
package libp2pwebtransport |
|||
|
|||
import ( |
|||
"crypto" |
|||
"crypto/ecdsa" |
|||
"crypto/elliptic" |
|||
"crypto/rand" |
|||
"crypto/rsa" |
|||
"crypto/sha256" |
|||
"crypto/x509" |
|||
"crypto/x509/pkix" |
|||
"fmt" |
|||
"math/big" |
|||
mrand "math/rand" |
|||
"testing" |
|||
"time" |
|||
|
|||
"github.com/multiformats/go-multihash" |
|||
"github.com/stretchr/testify/require" |
|||
) |
|||
|
|||
func sha256Multihash(t *testing.T, b []byte) multihash.DecodedMultihash { |
|||
t.Helper() |
|||
hash := sha256.Sum256(b) |
|||
h, err := multihash.Encode(hash[:], multihash.SHA2_256) |
|||
require.NoError(t, err) |
|||
dh, err := multihash.Decode(h) |
|||
require.NoError(t, err) |
|||
return *dh |
|||
} |
|||
|
|||
func generateCertWithKey(t *testing.T, key crypto.PrivateKey, start, end time.Time) *x509.Certificate { |
|||
t.Helper() |
|||
serial := int64(mrand.Uint64()) |
|||
if serial < 0 { |
|||
serial = -serial |
|||
} |
|||
certTempl := &x509.Certificate{ |
|||
SerialNumber: big.NewInt(serial), |
|||
Subject: pkix.Name{}, |
|||
NotBefore: start, |
|||
NotAfter: end, |
|||
IsCA: true, |
|||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, |
|||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, |
|||
BasicConstraintsValid: true, |
|||
} |
|||
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, key.(interface{ Public() crypto.PublicKey }).Public(), key) |
|||
require.NoError(t, err) |
|||
ca, err := x509.ParseCertificate(caBytes) |
|||
require.NoError(t, err) |
|||
return ca |
|||
} |
|||
|
|||
func TestCertificateVerification(t *testing.T) { |
|||
now := time.Now() |
|||
ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) |
|||
require.NoError(t, err) |
|||
rsaKey, err := rsa.GenerateKey(rand.Reader, 1024) |
|||
require.NoError(t, err) |
|||
|
|||
t.Run("accepting a valid cert", func(t *testing.T) { |
|||
validCert := generateCertWithKey(t, ecdsaKey, now, now.Add(14*24*time.Hour)) |
|||
require.NoError(t, verifyRawCerts([][]byte{validCert.Raw}, []multihash.DecodedMultihash{sha256Multihash(t, validCert.Raw)})) |
|||
}) |
|||
|
|||
for _, tc := range [...]struct { |
|||
name string |
|||
cert *x509.Certificate |
|||
errStr string |
|||
}{ |
|||
{ |
|||
name: "validitity period too long", |
|||
cert: generateCertWithKey(t, ecdsaKey, now, now.Add(15*24*time.Hour)), |
|||
errStr: "cert must not be valid for longer than 14 days", |
|||
}, |
|||
{ |
|||
name: "uses RSA key", |
|||
cert: generateCertWithKey(t, rsaKey, now, now.Add(14*24*time.Hour)), |
|||
errStr: "RSA", |
|||
}, |
|||
{ |
|||
name: "expired certificate", |
|||
cert: generateCertWithKey(t, ecdsaKey, now.Add(-14*24*time.Hour), now), |
|||
errStr: "cert not valid", |
|||
}, |
|||
{ |
|||
name: "not yet valid", |
|||
cert: generateCertWithKey(t, ecdsaKey, now.Add(time.Hour), now.Add(time.Hour+14*24*time.Hour)), |
|||
errStr: "cert not valid", |
|||
}, |
|||
} { |
|||
tc := tc |
|||
t.Run(fmt.Sprintf("rejecting invalid certificates: %s", tc.name), func(t *testing.T) { |
|||
err := verifyRawCerts([][]byte{tc.cert.Raw}, []multihash.DecodedMultihash{sha256Multihash(t, tc.cert.Raw)}) |
|||
require.Error(t, err) |
|||
require.Contains(t, err.Error(), tc.errStr) |
|||
}) |
|||
} |
|||
|
|||
for _, tc := range [...]struct { |
|||
name string |
|||
certs [][]byte |
|||
hashes []multihash.DecodedMultihash |
|||
errStr string |
|||
}{ |
|||
{ |
|||
name: "no certificates", |
|||
hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))}, |
|||
errStr: "no cert", |
|||
}, |
|||
{ |
|||
name: "certificate not parseable", |
|||
certs: [][]byte{[]byte("foobar")}, |
|||
hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))}, |
|||
errStr: "x509: malformed certificate", |
|||
}, |
|||
{ |
|||
name: "hash mismatch", |
|||
certs: [][]byte{generateCertWithKey(t, ecdsaKey, now, now.Add(15*24*time.Hour)).Raw}, |
|||
hashes: []multihash.DecodedMultihash{sha256Multihash(t, []byte("foobar"))}, |
|||
errStr: "cert hash not found", |
|||
}, |
|||
} { |
|||
tc := tc |
|||
t.Run(fmt.Sprintf("rejecting invalid certificates: %s", tc.name), func(t *testing.T) { |
|||
err := verifyRawCerts(tc.certs, tc.hashes) |
|||
require.Error(t, err) |
|||
require.Contains(t, err.Error(), tc.errStr) |
|||
}) |
|||
} |
|||
} |
@ -0,0 +1,249 @@ |
|||
package libp2pwebtransport |
|||
|
|||
import ( |
|||
"context" |
|||
"crypto/tls" |
|||
"errors" |
|||
"fmt" |
|||
pb "github.com/marten-seemann/go-libp2p-webtransport/pb" |
|||
"github.com/multiformats/go-multihash" |
|||
"net" |
|||
"net/http" |
|||
"time" |
|||
|
|||
"github.com/libp2p/go-libp2p/core/connmgr" |
|||
"github.com/libp2p/go-libp2p/core/network" |
|||
tpt "github.com/libp2p/go-libp2p/core/transport" |
|||
"github.com/libp2p/go-libp2p/p2p/security/noise" |
|||
|
|||
"github.com/lucas-clemente/quic-go/http3" |
|||
"github.com/marten-seemann/webtransport-go" |
|||
ma "github.com/multiformats/go-multiaddr" |
|||
manet "github.com/multiformats/go-multiaddr/net" |
|||
) |
|||
|
|||
var errClosed = errors.New("closed") |
|||
|
|||
const queueLen = 16 |
|||
const handshakeTimeout = 10 * time.Second |
|||
|
|||
type listener struct { |
|||
transport tpt.Transport |
|||
noise *noise.Transport |
|||
certManager *certManager |
|||
tlsConf *tls.Config |
|||
isStaticTLSConf bool |
|||
|
|||
rcmgr network.ResourceManager |
|||
gater connmgr.ConnectionGater |
|||
|
|||
server webtransport.Server |
|||
|
|||
ctx context.Context |
|||
ctxCancel context.CancelFunc |
|||
|
|||
serverClosed chan struct{} // is closed when server.Serve returns
|
|||
|
|||
addr net.Addr |
|||
multiaddr ma.Multiaddr |
|||
|
|||
queue chan tpt.CapableConn |
|||
} |
|||
|
|||
var _ tpt.Listener = &listener{} |
|||
|
|||
func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, tlsConf *tls.Config, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Listener, error) { |
|||
network, addr, err := manet.DialArgs(laddr) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
udpAddr, err := net.ResolveUDPAddr(network, addr) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
udpConn, err := net.ListenUDP(network, udpAddr) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
localMultiaddr, err := toWebtransportMultiaddr(udpConn.LocalAddr()) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
isStaticTLSConf := tlsConf != nil |
|||
if tlsConf == nil { |
|||
tlsConf = &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { |
|||
return certManager.GetConfig(), nil |
|||
}} |
|||
} |
|||
ln := &listener{ |
|||
transport: transport, |
|||
noise: noise, |
|||
certManager: certManager, |
|||
tlsConf: tlsConf, |
|||
isStaticTLSConf: isStaticTLSConf, |
|||
rcmgr: rcmgr, |
|||
gater: gater, |
|||
queue: make(chan tpt.CapableConn, queueLen), |
|||
serverClosed: make(chan struct{}), |
|||
addr: udpConn.LocalAddr(), |
|||
multiaddr: localMultiaddr, |
|||
server: webtransport.Server{H3: http3.Server{TLSConfig: tlsConf}}, |
|||
} |
|||
ln.ctx, ln.ctxCancel = context.WithCancel(context.Background()) |
|||
mux := http.NewServeMux() |
|||
mux.HandleFunc(webtransportHTTPEndpoint, ln.httpHandler) |
|||
ln.server.H3.Handler = mux |
|||
go func() { |
|||
defer close(ln.serverClosed) |
|||
defer func() { udpConn.Close() }() |
|||
if err := ln.server.Serve(udpConn); err != nil { |
|||
// TODO: only output if the server hasn't been closed
|
|||
log.Debugw("serving failed", "addr", udpConn.LocalAddr(), "error", err) |
|||
} |
|||
}() |
|||
return ln, nil |
|||
} |
|||
|
|||
func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { |
|||
remoteMultiaddr, err := stringToWebtransportMultiaddr(r.RemoteAddr) |
|||
if err != nil { |
|||
// This should never happen.
|
|||
log.Errorw("converting remote address failed", "remote", r.RemoteAddr, "error", err) |
|||
w.WriteHeader(http.StatusBadRequest) |
|||
return |
|||
} |
|||
if l.gater != nil && !l.gater.InterceptAccept(&connMultiaddrs{local: l.multiaddr, remote: remoteMultiaddr}) { |
|||
w.WriteHeader(http.StatusForbidden) |
|||
return |
|||
} |
|||
|
|||
connScope, err := l.rcmgr.OpenConnection(network.DirInbound, false, remoteMultiaddr) |
|||
if err != nil { |
|||
log.Debugw("resource manager blocked incoming connection", "addr", r.RemoteAddr, "error", err) |
|||
w.WriteHeader(http.StatusServiceUnavailable) |
|||
return |
|||
} |
|||
|
|||
// TODO: check ?type=multistream URL param
|
|||
sess, err := l.server.Upgrade(w, r) |
|||
if err != nil { |
|||
log.Debugw("upgrade failed", "error", err) |
|||
// TODO: think about the status code to use here
|
|||
w.WriteHeader(500) |
|||
connScope.Done() |
|||
return |
|||
} |
|||
ctx, cancel := context.WithTimeout(l.ctx, handshakeTimeout) |
|||
sconn, err := l.handshake(ctx, sess) |
|||
if err != nil { |
|||
cancel() |
|||
log.Debugw("handshake failed", "error", err) |
|||
sess.Close() |
|||
connScope.Done() |
|||
return |
|||
} |
|||
cancel() |
|||
|
|||
if l.gater != nil && !l.gater.InterceptSecured(network.DirInbound, sconn.RemotePeer(), sconn) { |
|||
// TODO: can we close with a specific error here?
|
|||
sess.Close() |
|||
connScope.Done() |
|||
return |
|||
} |
|||
|
|||
if err := connScope.SetPeer(sconn.RemotePeer()); err != nil { |
|||
log.Debugw("resource manager blocked incoming connection for peer", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err) |
|||
sess.Close() |
|||
connScope.Done() |
|||
return |
|||
} |
|||
|
|||
select { |
|||
case l.queue <- newConn(l.transport, sess, sconn, connScope): |
|||
default: |
|||
log.Debugw("accept queue full, dropping incoming connection", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err) |
|||
sess.Close() |
|||
connScope.Done() |
|||
} |
|||
} |
|||
|
|||
func (l *listener) Accept() (tpt.CapableConn, error) { |
|||
select { |
|||
case <-l.ctx.Done(): |
|||
return nil, errClosed |
|||
case c := <-l.queue: |
|||
return c, nil |
|||
} |
|||
} |
|||
|
|||
func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*connSecurityMultiaddrs, error) { |
|||
local, err := toWebtransportMultiaddr(sess.LocalAddr()) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("error determiniting local addr: %w", err) |
|||
} |
|||
remote, err := toWebtransportMultiaddr(sess.RemoteAddr()) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("error determiniting remote addr: %w", err) |
|||
} |
|||
|
|||
str, err := sess.AcceptStream(ctx) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
n, err := l.noise.WithSessionOptions(noise.EarlyData(newEarlyDataReceiver(l.checkEarlyData))) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to initialize Noise session: %w", err) |
|||
} |
|||
c, err := n.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
return &connSecurityMultiaddrs{ |
|||
ConnSecurity: c, |
|||
ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote}, |
|||
}, nil |
|||
} |
|||
|
|||
func (l *listener) checkEarlyData(b []byte) error { |
|||
var msg pb.WebTransport |
|||
if err := msg.Unmarshal(b); err != nil { |
|||
fmt.Println(1) |
|||
return fmt.Errorf("failed to unmarshal early data protobuf: %w", err) |
|||
} |
|||
|
|||
if l.isStaticTLSConf { |
|||
if len(msg.CertHashes) > 0 { |
|||
return errors.New("using static TLS config, didn't expect any certificate hashes") |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
hashes := make([]multihash.DecodedMultihash, 0, len(msg.CertHashes)) |
|||
for _, h := range msg.CertHashes { |
|||
dh, err := multihash.Decode(h) |
|||
if err != nil { |
|||
return fmt.Errorf("failed to decode hash: %w", err) |
|||
} |
|||
hashes = append(hashes, *dh) |
|||
} |
|||
return l.certManager.Verify(hashes) |
|||
} |
|||
|
|||
func (l *listener) Addr() net.Addr { |
|||
return l.addr |
|||
} |
|||
|
|||
func (l *listener) Multiaddr() ma.Multiaddr { |
|||
if l.certManager == nil { |
|||
return l.multiaddr |
|||
} |
|||
return l.multiaddr.Encapsulate(l.certManager.AddrComponent()) |
|||
} |
|||
|
|||
func (l *listener) Close() error { |
|||
l.ctxCancel() |
|||
err := l.server.Close() |
|||
<-l.serverClosed |
|||
return err |
|||
} |
@ -0,0 +1,109 @@ |
|||
// Code generated by MockGen. DO NOT EDIT.
|
|||
// Source: github.com/libp2p/go-libp2p/core/connmgr (interfaces: ConnectionGater)
|
|||
|
|||
// Package libp2pwebtransport_test is a generated GoMock package.
|
|||
package libp2pwebtransport_test |
|||
|
|||
import ( |
|||
reflect "reflect" |
|||
|
|||
gomock "github.com/golang/mock/gomock" |
|||
control "github.com/libp2p/go-libp2p/core/control" |
|||
network "github.com/libp2p/go-libp2p/core/network" |
|||
peer "github.com/libp2p/go-libp2p/core/peer" |
|||
multiaddr "github.com/multiformats/go-multiaddr" |
|||
) |
|||
|
|||
// MockConnectionGater is a mock of ConnectionGater interface.
|
|||
type MockConnectionGater struct { |
|||
ctrl *gomock.Controller |
|||
recorder *MockConnectionGaterMockRecorder |
|||
} |
|||
|
|||
// MockConnectionGaterMockRecorder is the mock recorder for MockConnectionGater.
|
|||
type MockConnectionGaterMockRecorder struct { |
|||
mock *MockConnectionGater |
|||
} |
|||
|
|||
// NewMockConnectionGater creates a new mock instance.
|
|||
func NewMockConnectionGater(ctrl *gomock.Controller) *MockConnectionGater { |
|||
mock := &MockConnectionGater{ctrl: ctrl} |
|||
mock.recorder = &MockConnectionGaterMockRecorder{mock} |
|||
return mock |
|||
} |
|||
|
|||
// EXPECT returns an object that allows the caller to indicate expected use.
|
|||
func (m *MockConnectionGater) EXPECT() *MockConnectionGaterMockRecorder { |
|||
return m.recorder |
|||
} |
|||
|
|||
// InterceptAccept mocks base method.
|
|||
func (m *MockConnectionGater) InterceptAccept(arg0 network.ConnMultiaddrs) bool { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "InterceptAccept", arg0) |
|||
ret0, _ := ret[0].(bool) |
|||
return ret0 |
|||
} |
|||
|
|||
// InterceptAccept indicates an expected call of InterceptAccept.
|
|||
func (mr *MockConnectionGaterMockRecorder) InterceptAccept(arg0 interface{}) *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAccept", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAccept), arg0) |
|||
} |
|||
|
|||
// InterceptAddrDial mocks base method.
|
|||
func (m *MockConnectionGater) InterceptAddrDial(arg0 peer.ID, arg1 multiaddr.Multiaddr) bool { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "InterceptAddrDial", arg0, arg1) |
|||
ret0, _ := ret[0].(bool) |
|||
return ret0 |
|||
} |
|||
|
|||
// InterceptAddrDial indicates an expected call of InterceptAddrDial.
|
|||
func (mr *MockConnectionGaterMockRecorder) InterceptAddrDial(arg0, arg1 interface{}) *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptAddrDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptAddrDial), arg0, arg1) |
|||
} |
|||
|
|||
// InterceptPeerDial mocks base method.
|
|||
func (m *MockConnectionGater) InterceptPeerDial(arg0 peer.ID) bool { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "InterceptPeerDial", arg0) |
|||
ret0, _ := ret[0].(bool) |
|||
return ret0 |
|||
} |
|||
|
|||
// InterceptPeerDial indicates an expected call of InterceptPeerDial.
|
|||
func (mr *MockConnectionGaterMockRecorder) InterceptPeerDial(arg0 interface{}) *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptPeerDial", reflect.TypeOf((*MockConnectionGater)(nil).InterceptPeerDial), arg0) |
|||
} |
|||
|
|||
// InterceptSecured mocks base method.
|
|||
func (m *MockConnectionGater) InterceptSecured(arg0 network.Direction, arg1 peer.ID, arg2 network.ConnMultiaddrs) bool { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "InterceptSecured", arg0, arg1, arg2) |
|||
ret0, _ := ret[0].(bool) |
|||
return ret0 |
|||
} |
|||
|
|||
// InterceptSecured indicates an expected call of InterceptSecured.
|
|||
func (mr *MockConnectionGaterMockRecorder) InterceptSecured(arg0, arg1, arg2 interface{}) *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptSecured", reflect.TypeOf((*MockConnectionGater)(nil).InterceptSecured), arg0, arg1, arg2) |
|||
} |
|||
|
|||
// InterceptUpgraded mocks base method.
|
|||
func (m *MockConnectionGater) InterceptUpgraded(arg0 network.Conn) (bool, control.DisconnectReason) { |
|||
m.ctrl.T.Helper() |
|||
ret := m.ctrl.Call(m, "InterceptUpgraded", arg0) |
|||
ret0, _ := ret[0].(bool) |
|||
ret1, _ := ret[1].(control.DisconnectReason) |
|||
return ret0, ret1 |
|||
} |
|||
|
|||
// InterceptUpgraded indicates an expected call of InterceptUpgraded.
|
|||
func (mr *MockConnectionGaterMockRecorder) InterceptUpgraded(arg0 interface{}) *gomock.Call { |
|||
mr.mock.ctrl.T.Helper() |
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InterceptUpgraded", reflect.TypeOf((*MockConnectionGater)(nil).InterceptUpgraded), arg0) |
|||
} |
@ -0,0 +1,80 @@ |
|||
package libp2pwebtransport |
|||
|
|||
import ( |
|||
"errors" |
|||
"fmt" |
|||
"net" |
|||
"strconv" |
|||
|
|||
ma "github.com/multiformats/go-multiaddr" |
|||
mafmt "github.com/multiformats/go-multiaddr-fmt" |
|||
manet "github.com/multiformats/go-multiaddr/net" |
|||
"github.com/multiformats/go-multibase" |
|||
"github.com/multiformats/go-multihash" |
|||
) |
|||
|
|||
var webtransportMA = ma.StringCast("/quic/webtransport") |
|||
|
|||
var webtransportMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC), mafmt.Base(ma.P_WEBTRANSPORT)) |
|||
|
|||
func toWebtransportMultiaddr(na net.Addr) (ma.Multiaddr, error) { |
|||
addr, err := manet.FromNetAddr(na) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if _, err := addr.ValueForProtocol(ma.P_UDP); err != nil { |
|||
return nil, errors.New("not a UDP address") |
|||
} |
|||
return addr.Encapsulate(webtransportMA), nil |
|||
} |
|||
|
|||
func stringToWebtransportMultiaddr(str string) (ma.Multiaddr, error) { |
|||
host, portStr, err := net.SplitHostPort(str) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
port, err := strconv.ParseInt(portStr, 10, 32) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
ip := net.ParseIP(host) |
|||
if ip == nil { |
|||
return nil, errors.New("failed to parse IP") |
|||
} |
|||
return toWebtransportMultiaddr(&net.UDPAddr{IP: ip, Port: int(port)}) |
|||
} |
|||
|
|||
func extractCertHashes(addr ma.Multiaddr) ([]multihash.DecodedMultihash, error) { |
|||
certHashesStr := make([]string, 0, 2) |
|||
ma.ForEach(addr, func(c ma.Component) bool { |
|||
if c.Protocol().Code == ma.P_CERTHASH { |
|||
certHashesStr = append(certHashesStr, c.Value()) |
|||
} |
|||
return true |
|||
}) |
|||
certHashes := make([]multihash.DecodedMultihash, 0, len(certHashesStr)) |
|||
for _, s := range certHashesStr { |
|||
_, ch, err := multibase.Decode(s) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to multibase-decode certificate hash: %w", err) |
|||
} |
|||
dh, err := multihash.Decode(ch) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to multihash-decode certificate hash: %w", err) |
|||
} |
|||
certHashes = append(certHashes, *dh) |
|||
} |
|||
return certHashes, nil |
|||
} |
|||
|
|||
func addrComponentForCert(hash []byte) (ma.Multiaddr, error) { |
|||
mh, err := multihash.Encode(hash, multihash.SHA2_256) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
certStr, err := multibase.Encode(multibase.Base58BTC, mh) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr) |
|||
} |
@ -0,0 +1,76 @@ |
|||
package libp2pwebtransport |
|||
|
|||
import ( |
|||
"fmt" |
|||
"net" |
|||
"testing" |
|||
|
|||
ma "github.com/multiformats/go-multiaddr" |
|||
"github.com/multiformats/go-multibase" |
|||
"github.com/multiformats/go-multihash" |
|||
"github.com/stretchr/testify/require" |
|||
) |
|||
|
|||
func TestWebtransportMultiaddr(t *testing.T) { |
|||
t.Run("valid", func(t *testing.T) { |
|||
addr, err := toWebtransportMultiaddr(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}) |
|||
require.NoError(t, err) |
|||
require.Equal(t, "/ip4/127.0.0.1/udp/1337/quic/webtransport", addr.String()) |
|||
}) |
|||
|
|||
t.Run("invalid", func(t *testing.T) { |
|||
_, err := toWebtransportMultiaddr(&net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}) |
|||
require.EqualError(t, err, "not a UDP address") |
|||
}) |
|||
} |
|||
|
|||
func TestWebtransportMultiaddrFromString(t *testing.T) { |
|||
t.Run("valid", func(t *testing.T) { |
|||
addr, err := stringToWebtransportMultiaddr("1.2.3.4:60042") |
|||
require.NoError(t, err) |
|||
require.Equal(t, "/ip4/1.2.3.4/udp/60042/quic/webtransport", addr.String()) |
|||
}) |
|||
|
|||
t.Run("invalid", func(t *testing.T) { |
|||
for _, addr := range [...]string{ |
|||
"1.2.3.4", // missing port
|
|||
"1.2.3.4:123456", // invalid port
|
|||
":1234", // missing IP
|
|||
"foobar", |
|||
} { |
|||
_, err := stringToWebtransportMultiaddr(addr) |
|||
require.Error(t, err) |
|||
} |
|||
}) |
|||
} |
|||
|
|||
func encodeCertHash(t *testing.T, b []byte, mh uint64, mb multibase.Encoding) string { |
|||
t.Helper() |
|||
h, err := multihash.Encode(b, mh) |
|||
require.NoError(t, err) |
|||
str, err := multibase.Encode(mb, h) |
|||
require.NoError(t, err) |
|||
return str |
|||
} |
|||
|
|||
func TestExtractCertHashes(t *testing.T) { |
|||
fooHash := encodeCertHash(t, []byte("foo"), multihash.SHA2_256, multibase.Base58BTC) |
|||
barHash := encodeCertHash(t, []byte("bar"), multihash.BLAKE2B_MAX, multibase.Base32) |
|||
|
|||
// valid cases
|
|||
for _, tc := range [...]struct { |
|||
addr string |
|||
hashes []string |
|||
}{ |
|||
{addr: "/ip4/127.0.0.1/udp/1234/quic/webtransport"}, |
|||
{addr: fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s", fooHash), hashes: []string{"foo"}}, |
|||
{addr: fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s/certhash/%s", fooHash, barHash), hashes: []string{"foo", "bar"}}, |
|||
} { |
|||
ch, err := extractCertHashes(ma.StringCast(tc.addr)) |
|||
require.NoError(t, err) |
|||
require.Len(t, ch, len(tc.hashes)) |
|||
for i, h := range tc.hashes { |
|||
require.Equal(t, h, string(ch[i].Digest)) |
|||
} |
|||
} |
|||
} |
@ -0,0 +1,34 @@ |
|||
package libp2pwebtransport |
|||
|
|||
import ( |
|||
"context" |
|||
"github.com/libp2p/go-libp2p/core/peer" |
|||
"github.com/libp2p/go-libp2p/p2p/security/noise" |
|||
"net" |
|||
) |
|||
|
|||
type earlyDataHandler struct { |
|||
earlyData []byte |
|||
receive func([]byte) error |
|||
} |
|||
|
|||
var _ noise.EarlyDataHandler = &earlyDataHandler{} |
|||
|
|||
func newEarlyDataSender(earlyData []byte) noise.EarlyDataHandler { |
|||
return &earlyDataHandler{earlyData: earlyData} |
|||
} |
|||
|
|||
func newEarlyDataReceiver(receive func([]byte) error) noise.EarlyDataHandler { |
|||
return &earlyDataHandler{receive: receive} |
|||
} |
|||
|
|||
func (e *earlyDataHandler) Send(context.Context, net.Conn, peer.ID) []byte { |
|||
return e.earlyData |
|||
} |
|||
|
|||
func (e *earlyDataHandler) Received(_ context.Context, _ net.Conn, b []byte) error { |
|||
if e.receive == nil { |
|||
return nil |
|||
} |
|||
return e.receive(b) |
|||
} |
@ -0,0 +1,11 @@ |
|||
PB = $(wildcard *.proto) |
|||
GO = $(PB:.proto=.pb.go) |
|||
|
|||
all: $(GO) |
|||
|
|||
%.pb.go: %.proto |
|||
protoc --proto_path=$(PWD)/../..:. --gogofaster_out=. $< |
|||
|
|||
clean: |
|||
rm -f *.pb.go |
|||
rm -f *.go |
@ -0,0 +1,315 @@ |
|||
// Code generated by protoc-gen-gogo. DO NOT EDIT.
|
|||
// source: webtransport.proto
|
|||
|
|||
package webtransport |
|||
|
|||
import ( |
|||
fmt "fmt" |
|||
proto "github.com/gogo/protobuf/proto" |
|||
io "io" |
|||
math "math" |
|||
math_bits "math/bits" |
|||
) |
|||
|
|||
// Reference imports to suppress errors if they are not otherwise used.
|
|||
var _ = proto.Marshal |
|||
var _ = fmt.Errorf |
|||
var _ = math.Inf |
|||
|
|||
// This is a compile-time assertion to ensure that this generated file
|
|||
// is compatible with the proto package it is being compiled against.
|
|||
// A compilation error at this line likely means your copy of the
|
|||
// proto package needs to be updated.
|
|||
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
|
|||
|
|||
type WebTransport struct { |
|||
CertHashes [][]byte `protobuf:"bytes,1,rep,name=cert_hashes,json=certHashes" json:"cert_hashes,omitempty"` |
|||
} |
|||
|
|||
func (m *WebTransport) Reset() { *m = WebTransport{} } |
|||
func (m *WebTransport) String() string { return proto.CompactTextString(m) } |
|||
func (*WebTransport) ProtoMessage() {} |
|||
func (*WebTransport) Descriptor() ([]byte, []int) { |
|||
return fileDescriptor_db878920ab41a4f3, []int{0} |
|||
} |
|||
func (m *WebTransport) XXX_Unmarshal(b []byte) error { |
|||
return m.Unmarshal(b) |
|||
} |
|||
func (m *WebTransport) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { |
|||
if deterministic { |
|||
return xxx_messageInfo_WebTransport.Marshal(b, m, deterministic) |
|||
} else { |
|||
b = b[:cap(b)] |
|||
n, err := m.MarshalToSizedBuffer(b) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return b[:n], nil |
|||
} |
|||
} |
|||
func (m *WebTransport) XXX_Merge(src proto.Message) { |
|||
xxx_messageInfo_WebTransport.Merge(m, src) |
|||
} |
|||
func (m *WebTransport) XXX_Size() int { |
|||
return m.Size() |
|||
} |
|||
func (m *WebTransport) XXX_DiscardUnknown() { |
|||
xxx_messageInfo_WebTransport.DiscardUnknown(m) |
|||
} |
|||
|
|||
var xxx_messageInfo_WebTransport proto.InternalMessageInfo |
|||
|
|||
func (m *WebTransport) GetCertHashes() [][]byte { |
|||
if m != nil { |
|||
return m.CertHashes |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func init() { |
|||
proto.RegisterType((*WebTransport)(nil), "WebTransport") |
|||
} |
|||
|
|||
func init() { proto.RegisterFile("webtransport.proto", fileDescriptor_db878920ab41a4f3) } |
|||
|
|||
var fileDescriptor_db878920ab41a4f3 = []byte{ |
|||
// 109 bytes of a gzipped FileDescriptorProto
|
|||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x2a, 0x4f, 0x4d, 0x2a, |
|||
0x29, 0x4a, 0xcc, 0x2b, 0x2e, 0xc8, 0x2f, 0x2a, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x57, 0xd2, |
|||
0xe7, 0xe2, 0x09, 0x4f, 0x4d, 0x0a, 0x81, 0x89, 0x0a, 0xc9, 0x73, 0x71, 0x27, 0xa7, 0x16, 0x95, |
|||
0xc4, 0x67, 0x24, 0x16, 0x67, 0xa4, 0x16, 0x4b, 0x30, 0x2a, 0x30, 0x6b, 0xf0, 0x04, 0x71, 0x81, |
|||
0x84, 0x3c, 0xc0, 0x22, 0x4e, 0x12, 0x27, 0x1e, 0xc9, 0x31, 0x5e, 0x78, 0x24, 0xc7, 0xf8, 0xe0, |
|||
0x91, 0x1c, 0xe3, 0x84, 0xc7, 0x72, 0x0c, 0x17, 0x1e, 0xcb, 0x31, 0xdc, 0x78, 0x2c, 0xc7, 0x00, |
|||
0x08, 0x00, 0x00, 0xff, 0xff, 0x50, 0x77, 0xe5, 0x52, 0x5f, 0x00, 0x00, 0x00, |
|||
} |
|||
|
|||
func (m *WebTransport) Marshal() (dAtA []byte, err error) { |
|||
size := m.Size() |
|||
dAtA = make([]byte, size) |
|||
n, err := m.MarshalToSizedBuffer(dAtA[:size]) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return dAtA[:n], nil |
|||
} |
|||
|
|||
func (m *WebTransport) MarshalTo(dAtA []byte) (int, error) { |
|||
size := m.Size() |
|||
return m.MarshalToSizedBuffer(dAtA[:size]) |
|||
} |
|||
|
|||
func (m *WebTransport) MarshalToSizedBuffer(dAtA []byte) (int, error) { |
|||
i := len(dAtA) |
|||
_ = i |
|||
var l int |
|||
_ = l |
|||
if len(m.CertHashes) > 0 { |
|||
for iNdEx := len(m.CertHashes) - 1; iNdEx >= 0; iNdEx-- { |
|||
i -= len(m.CertHashes[iNdEx]) |
|||
copy(dAtA[i:], m.CertHashes[iNdEx]) |
|||
i = encodeVarintWebtransport(dAtA, i, uint64(len(m.CertHashes[iNdEx]))) |
|||
i-- |
|||
dAtA[i] = 0xa |
|||
} |
|||
} |
|||
return len(dAtA) - i, nil |
|||
} |
|||
|
|||
func encodeVarintWebtransport(dAtA []byte, offset int, v uint64) int { |
|||
offset -= sovWebtransport(v) |
|||
base := offset |
|||
for v >= 1<<7 { |
|||
dAtA[offset] = uint8(v&0x7f | 0x80) |
|||
v >>= 7 |
|||
offset++ |
|||
} |
|||
dAtA[offset] = uint8(v) |
|||
return base |
|||
} |
|||
func (m *WebTransport) Size() (n int) { |
|||
if m == nil { |
|||
return 0 |
|||
} |
|||
var l int |
|||
_ = l |
|||
if len(m.CertHashes) > 0 { |
|||
for _, b := range m.CertHashes { |
|||
l = len(b) |
|||
n += 1 + l + sovWebtransport(uint64(l)) |
|||
} |
|||
} |
|||
return n |
|||
} |
|||
|
|||
func sovWebtransport(x uint64) (n int) { |
|||
return (math_bits.Len64(x|1) + 6) / 7 |
|||
} |
|||
func sozWebtransport(x uint64) (n int) { |
|||
return sovWebtransport(uint64((x << 1) ^ uint64((int64(x) >> 63)))) |
|||
} |
|||
func (m *WebTransport) Unmarshal(dAtA []byte) error { |
|||
l := len(dAtA) |
|||
iNdEx := 0 |
|||
for iNdEx < l { |
|||
preIndex := iNdEx |
|||
var wire uint64 |
|||
for shift := uint(0); ; shift += 7 { |
|||
if shift >= 64 { |
|||
return ErrIntOverflowWebtransport |
|||
} |
|||
if iNdEx >= l { |
|||
return io.ErrUnexpectedEOF |
|||
} |
|||
b := dAtA[iNdEx] |
|||
iNdEx++ |
|||
wire |= uint64(b&0x7F) << shift |
|||
if b < 0x80 { |
|||
break |
|||
} |
|||
} |
|||
fieldNum := int32(wire >> 3) |
|||
wireType := int(wire & 0x7) |
|||
if wireType == 4 { |
|||
return fmt.Errorf("proto: WebTransport: wiretype end group for non-group") |
|||
} |
|||
if fieldNum <= 0 { |
|||
return fmt.Errorf("proto: WebTransport: illegal tag %d (wire type %d)", fieldNum, wire) |
|||
} |
|||
switch fieldNum { |
|||
case 1: |
|||
if wireType != 2 { |
|||
return fmt.Errorf("proto: wrong wireType = %d for field CertHashes", wireType) |
|||
} |
|||
var byteLen int |
|||
for shift := uint(0); ; shift += 7 { |
|||
if shift >= 64 { |
|||
return ErrIntOverflowWebtransport |
|||
} |
|||
if iNdEx >= l { |
|||
return io.ErrUnexpectedEOF |
|||
} |
|||
b := dAtA[iNdEx] |
|||
iNdEx++ |
|||
byteLen |= int(b&0x7F) << shift |
|||
if b < 0x80 { |
|||
break |
|||
} |
|||
} |
|||
if byteLen < 0 { |
|||
return ErrInvalidLengthWebtransport |
|||
} |
|||
postIndex := iNdEx + byteLen |
|||
if postIndex < 0 { |
|||
return ErrInvalidLengthWebtransport |
|||
} |
|||
if postIndex > l { |
|||
return io.ErrUnexpectedEOF |
|||
} |
|||
m.CertHashes = append(m.CertHashes, make([]byte, postIndex-iNdEx)) |
|||
copy(m.CertHashes[len(m.CertHashes)-1], dAtA[iNdEx:postIndex]) |
|||
iNdEx = postIndex |
|||
default: |
|||
iNdEx = preIndex |
|||
skippy, err := skipWebtransport(dAtA[iNdEx:]) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
if (skippy < 0) || (iNdEx+skippy) < 0 { |
|||
return ErrInvalidLengthWebtransport |
|||
} |
|||
if (iNdEx + skippy) > l { |
|||
return io.ErrUnexpectedEOF |
|||
} |
|||
iNdEx += skippy |
|||
} |
|||
} |
|||
|
|||
if iNdEx > l { |
|||
return io.ErrUnexpectedEOF |
|||
} |
|||
return nil |
|||
} |
|||
func skipWebtransport(dAtA []byte) (n int, err error) { |
|||
l := len(dAtA) |
|||
iNdEx := 0 |
|||
depth := 0 |
|||
for iNdEx < l { |
|||
var wire uint64 |
|||
for shift := uint(0); ; shift += 7 { |
|||
if shift >= 64 { |
|||
return 0, ErrIntOverflowWebtransport |
|||
} |
|||
if iNdEx >= l { |
|||
return 0, io.ErrUnexpectedEOF |
|||
} |
|||
b := dAtA[iNdEx] |
|||
iNdEx++ |
|||
wire |= (uint64(b) & 0x7F) << shift |
|||
if b < 0x80 { |
|||
break |
|||
} |
|||
} |
|||
wireType := int(wire & 0x7) |
|||
switch wireType { |
|||
case 0: |
|||
for shift := uint(0); ; shift += 7 { |
|||
if shift >= 64 { |
|||
return 0, ErrIntOverflowWebtransport |
|||
} |
|||
if iNdEx >= l { |
|||
return 0, io.ErrUnexpectedEOF |
|||
} |
|||
iNdEx++ |
|||
if dAtA[iNdEx-1] < 0x80 { |
|||
break |
|||
} |
|||
} |
|||
case 1: |
|||
iNdEx += 8 |
|||
case 2: |
|||
var length int |
|||
for shift := uint(0); ; shift += 7 { |
|||
if shift >= 64 { |
|||
return 0, ErrIntOverflowWebtransport |
|||
} |
|||
if iNdEx >= l { |
|||
return 0, io.ErrUnexpectedEOF |
|||
} |
|||
b := dAtA[iNdEx] |
|||
iNdEx++ |
|||
length |= (int(b) & 0x7F) << shift |
|||
if b < 0x80 { |
|||
break |
|||
} |
|||
} |
|||
if length < 0 { |
|||
return 0, ErrInvalidLengthWebtransport |
|||
} |
|||
iNdEx += length |
|||
case 3: |
|||
depth++ |
|||
case 4: |
|||
if depth == 0 { |
|||
return 0, ErrUnexpectedEndOfGroupWebtransport |
|||
} |
|||
depth-- |
|||
case 5: |
|||
iNdEx += 4 |
|||
default: |
|||
return 0, fmt.Errorf("proto: illegal wireType %d", wireType) |
|||
} |
|||
if iNdEx < 0 { |
|||
return 0, ErrInvalidLengthWebtransport |
|||
} |
|||
if depth == 0 { |
|||
return iNdEx, nil |
|||
} |
|||
} |
|||
return 0, io.ErrUnexpectedEOF |
|||
} |
|||
|
|||
var ( |
|||
ErrInvalidLengthWebtransport = fmt.Errorf("proto: negative length found during unmarshaling") |
|||
ErrIntOverflowWebtransport = fmt.Errorf("proto: integer overflow") |
|||
ErrUnexpectedEndOfGroupWebtransport = fmt.Errorf("proto: unexpected end of group") |
|||
) |
@ -0,0 +1,5 @@ |
|||
syntax = "proto2"; |
|||
|
|||
message WebTransport { |
|||
repeated bytes cert_hashes = 1; |
|||
} |
@ -0,0 +1,71 @@ |
|||
package libp2pwebtransport |
|||
|
|||
import ( |
|||
"errors" |
|||
"net" |
|||
|
|||
"github.com/marten-seemann/webtransport-go" |
|||
|
|||
"github.com/libp2p/go-libp2p/core/network" |
|||
) |
|||
|
|||
const ( |
|||
reset webtransport.ErrorCode = 0 |
|||
) |
|||
|
|||
type webtransportStream struct { |
|||
webtransport.Stream |
|||
wsess *webtransport.Session |
|||
} |
|||
|
|||
var _ net.Conn = &webtransportStream{} |
|||
|
|||
func (s *webtransportStream) LocalAddr() net.Addr { |
|||
return s.wsess.LocalAddr() |
|||
} |
|||
|
|||
func (s *webtransportStream) RemoteAddr() net.Addr { |
|||
return s.wsess.RemoteAddr() |
|||
} |
|||
|
|||
type stream struct { |
|||
webtransport.Stream |
|||
} |
|||
|
|||
var _ network.MuxedStream = &stream{} |
|||
|
|||
func (s *stream) Read(b []byte) (n int, err error) { |
|||
n, err = s.Stream.Read(b) |
|||
if err != nil && errors.Is(err, &webtransport.StreamError{}) { |
|||
err = network.ErrReset |
|||
} |
|||
return n, err |
|||
} |
|||
|
|||
func (s *stream) Write(b []byte) (n int, err error) { |
|||
n, err = s.Stream.Write(b) |
|||
if err != nil && errors.Is(err, &webtransport.StreamError{}) { |
|||
err = network.ErrReset |
|||
} |
|||
return n, err |
|||
} |
|||
|
|||
func (s *stream) Reset() error { |
|||
s.Stream.CancelRead(reset) |
|||
s.Stream.CancelWrite(reset) |
|||
return nil |
|||
} |
|||
|
|||
func (s *stream) Close() error { |
|||
s.Stream.CancelRead(reset) |
|||
return s.Stream.Close() |
|||
} |
|||
|
|||
func (s *stream) CloseRead() error { |
|||
s.Stream.CancelRead(reset) |
|||
return nil |
|||
} |
|||
|
|||
func (s *stream) CloseWrite() error { |
|||
return s.Stream.Close() |
|||
} |
@ -0,0 +1,271 @@ |
|||
package libp2pwebtransport |
|||
|
|||
import ( |
|||
"context" |
|||
"crypto/tls" |
|||
"crypto/x509" |
|||
"fmt" |
|||
"io" |
|||
"sync" |
|||
"time" |
|||
|
|||
pb "github.com/marten-seemann/go-libp2p-webtransport/pb" |
|||
|
|||
"github.com/libp2p/go-libp2p/core/connmgr" |
|||
ic "github.com/libp2p/go-libp2p/core/crypto" |
|||
"github.com/libp2p/go-libp2p/core/network" |
|||
"github.com/libp2p/go-libp2p/core/peer" |
|||
tpt "github.com/libp2p/go-libp2p/core/transport" |
|||
"github.com/libp2p/go-libp2p/p2p/security/noise" |
|||
|
|||
"github.com/benbjohnson/clock" |
|||
logging "github.com/ipfs/go-log/v2" |
|||
"github.com/lucas-clemente/quic-go/http3" |
|||
"github.com/marten-seemann/webtransport-go" |
|||
ma "github.com/multiformats/go-multiaddr" |
|||
manet "github.com/multiformats/go-multiaddr/net" |
|||
"github.com/multiformats/go-multihash" |
|||
) |
|||
|
|||
var log = logging.Logger("webtransport") |
|||
|
|||
const webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport" |
|||
|
|||
const certValidity = 14 * 24 * time.Hour |
|||
|
|||
type Option func(*transport) error |
|||
|
|||
func WithClock(cl clock.Clock) Option { |
|||
return func(t *transport) error { |
|||
t.clock = cl |
|||
return nil |
|||
} |
|||
} |
|||
|
|||
// WithTLSConfig sets a tls.Config used for listening.
|
|||
// When used, the certificate from that config will be used, and no /certhash will be added to the listener's multiaddr.
|
|||
// This is most useful when running a listener that has a valid (CA-signed) certificate.
|
|||
func WithTLSConfig(c *tls.Config) Option { |
|||
return func(t *transport) error { |
|||
t.staticTLSConf = c |
|||
return nil |
|||
} |
|||
} |
|||
|
|||
// WithTLSClientConfig sets a custom tls.Config used for dialing.
|
|||
// This option is most useful for setting a custom tls.Config.RootCAs certificate pool.
|
|||
// When dialing a multiaddr that contains a /certhash component, this library will set InsecureSkipVerify and
|
|||
// overwrite the VerifyPeerCertificate callback.
|
|||
func WithTLSClientConfig(c *tls.Config) Option { |
|||
return func(t *transport) error { |
|||
t.tlsClientConf = c |
|||
return nil |
|||
} |
|||
} |
|||
|
|||
type transport struct { |
|||
privKey ic.PrivKey |
|||
pid peer.ID |
|||
clock clock.Clock |
|||
|
|||
rcmgr network.ResourceManager |
|||
gater connmgr.ConnectionGater |
|||
|
|||
listenOnce sync.Once |
|||
listenOnceErr error |
|||
certManager *certManager |
|||
staticTLSConf *tls.Config |
|||
tlsClientConf *tls.Config |
|||
|
|||
noise *noise.Transport |
|||
} |
|||
|
|||
var _ tpt.Transport = &transport{} |
|||
var _ io.Closer = &transport{} |
|||
|
|||
func New(key ic.PrivKey, gater connmgr.ConnectionGater, rcmgr network.ResourceManager, opts ...Option) (tpt.Transport, error) { |
|||
id, err := peer.IDFromPrivateKey(key) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
t := &transport{ |
|||
pid: id, |
|||
privKey: key, |
|||
rcmgr: rcmgr, |
|||
gater: gater, |
|||
clock: clock.New(), |
|||
} |
|||
for _, opt := range opts { |
|||
if err := opt(t); err != nil { |
|||
return nil, err |
|||
} |
|||
} |
|||
n, err := noise.New(key) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
t.noise = n |
|||
return t, nil |
|||
} |
|||
|
|||
func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { |
|||
_, addr, err := manet.DialArgs(raddr) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
certHashes, err := extractCertHashes(raddr) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) |
|||
if err != nil { |
|||
log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) |
|||
return nil, err |
|||
} |
|||
if err := scope.SetPeer(p); err != nil { |
|||
log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err) |
|||
scope.Done() |
|||
return nil, err |
|||
} |
|||
|
|||
sess, err := t.dial(ctx, addr, certHashes) |
|||
if err != nil { |
|||
scope.Done() |
|||
return nil, err |
|||
} |
|||
sconn, err := t.upgrade(ctx, sess, p, certHashes) |
|||
if err != nil { |
|||
sess.Close() |
|||
scope.Done() |
|||
return nil, err |
|||
} |
|||
if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, sconn) { |
|||
// TODO: can we close with a specific error here?
|
|||
sess.Close() |
|||
scope.Done() |
|||
return nil, fmt.Errorf("secured connection gated") |
|||
} |
|||
|
|||
return newConn(t, sess, sconn, scope), nil |
|||
} |
|||
|
|||
func (t *transport) dial(ctx context.Context, addr string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) { |
|||
url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint) |
|||
var tlsConf *tls.Config |
|||
if t.tlsClientConf != nil { |
|||
tlsConf = t.tlsClientConf.Clone() |
|||
} else { |
|||
tlsConf = &tls.Config{} |
|||
} |
|||
|
|||
if len(certHashes) > 0 { |
|||
// This is not insecure. We verify the certificate ourselves.
|
|||
// See https://www.w3.org/TR/webtransport/#certificate-hashes.
|
|||
tlsConf.InsecureSkipVerify = true |
|||
tlsConf.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { |
|||
return verifyRawCerts(rawCerts, certHashes) |
|||
} |
|||
} |
|||
dialer := webtransport.Dialer{ |
|||
RoundTripper: &http3.RoundTripper{TLSClientConfig: tlsConf}, |
|||
} |
|||
rsp, sess, err := dialer.Dial(ctx, url, nil) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if rsp.StatusCode < 200 || rsp.StatusCode > 299 { |
|||
return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode) |
|||
} |
|||
return sess, err |
|||
} |
|||
|
|||
func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) { |
|||
local, err := toWebtransportMultiaddr(sess.LocalAddr()) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("error determiniting local addr: %w", err) |
|||
} |
|||
remote, err := toWebtransportMultiaddr(sess.RemoteAddr()) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("error determiniting remote addr: %w", err) |
|||
} |
|||
|
|||
str, err := sess.OpenStreamSync(ctx) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
// Now run a Noise handshake (using early data) and send all the certificate hashes that we would have accepted.
|
|||
// The server will verify that it advertised all of these certificate hashes.
|
|||
msg := pb.WebTransport{CertHashes: make([][]byte, 0, len(certHashes))} |
|||
for _, certHash := range certHashes { |
|||
h, err := multihash.Encode(certHash.Digest, certHash.Code) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to encode certificate hash: %w", err) |
|||
} |
|||
msg.CertHashes = append(msg.CertHashes, h) |
|||
} |
|||
msgBytes, err := msg.Marshal() |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err) |
|||
} |
|||
n, err := t.noise.WithSessionOptions(noise.EarlyData(newEarlyDataSender(msgBytes))) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("failed to create Noise transport: %w", err) |
|||
} |
|||
c, err := n.SecureOutbound(ctx, &webtransportStream{Stream: str, wsess: sess}, p) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return &connSecurityMultiaddrs{ |
|||
ConnSecurity: c, |
|||
ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote}, |
|||
}, nil |
|||
} |
|||
|
|||
func (t *transport) CanDial(addr ma.Multiaddr) bool { |
|||
var numHashes int |
|||
ma.ForEach(addr, func(c ma.Component) bool { |
|||
if c.Protocol().Code == ma.P_CERTHASH { |
|||
numHashes++ |
|||
} |
|||
return true |
|||
}) |
|||
// Remove the /certhash components from the multiaddr.
|
|||
// If the multiaddr doesn't contain any certhashes, the node might have a CA-signed certificate.
|
|||
for i := 0; i < numHashes; i++ { |
|||
addr, _ = ma.SplitLast(addr) |
|||
} |
|||
return webtransportMatcher.Matches(addr) |
|||
} |
|||
|
|||
func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { |
|||
if !webtransportMatcher.Matches(laddr) { |
|||
return nil, fmt.Errorf("cannot listen on non-WebTransport addr: %s", laddr) |
|||
} |
|||
if t.staticTLSConf == nil { |
|||
t.listenOnce.Do(func() { |
|||
t.certManager, t.listenOnceErr = newCertManager(t.clock) |
|||
}) |
|||
if t.listenOnceErr != nil { |
|||
return nil, t.listenOnceErr |
|||
} |
|||
} |
|||
return newListener(laddr, t, t.noise, t.certManager, t.staticTLSConf, t.gater, t.rcmgr) |
|||
} |
|||
|
|||
func (t *transport) Protocols() []int { |
|||
return []int{ma.P_WEBTRANSPORT} |
|||
} |
|||
|
|||
func (t *transport) Proxy() bool { |
|||
return false |
|||
} |
|||
|
|||
func (t *transport) Close() error { |
|||
t.listenOnce.Do(func() {}) |
|||
if t.certManager != nil { |
|||
return t.certManager.Close() |
|||
} |
|||
return nil |
|||
} |
@ -0,0 +1,545 @@ |
|||
package libp2pwebtransport_test |
|||
|
|||
import ( |
|||
"context" |
|||
"crypto/ecdsa" |
|||
"crypto/elliptic" |
|||
"crypto/rand" |
|||
"crypto/sha256" |
|||
"crypto/tls" |
|||
"crypto/x509" |
|||
"crypto/x509/pkix" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"math/big" |
|||
"net" |
|||
"strings" |
|||
"testing" |
|||
"time" |
|||
|
|||
libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport" |
|||
|
|||
ic "github.com/libp2p/go-libp2p/core/crypto" |
|||
"github.com/libp2p/go-libp2p/core/network" |
|||
mocknetwork "github.com/libp2p/go-libp2p/core/network/mocks" |
|||
"github.com/libp2p/go-libp2p/core/peer" |
|||
tpt "github.com/libp2p/go-libp2p/core/transport" |
|||
|
|||
"github.com/golang/mock/gomock" |
|||
ma "github.com/multiformats/go-multiaddr" |
|||
manet "github.com/multiformats/go-multiaddr/net" |
|||
"github.com/multiformats/go-multibase" |
|||
"github.com/multiformats/go-multihash" |
|||
|
|||
"github.com/stretchr/testify/require" |
|||
) |
|||
|
|||
func newIdentity(t *testing.T) (peer.ID, ic.PrivKey) { |
|||
key, _, err := ic.GenerateEd25519Key(rand.Reader) |
|||
require.NoError(t, err) |
|||
id, err := peer.IDFromPrivateKey(key) |
|||
require.NoError(t, err) |
|||
return id, key |
|||
} |
|||
|
|||
func randomMultihash(t *testing.T) string { |
|||
t.Helper() |
|||
b := make([]byte, 16) |
|||
rand.Read(b) |
|||
h, err := multihash.Encode(b, multihash.KECCAK_224) |
|||
require.NoError(t, err) |
|||
s, err := multibase.Encode(multibase.Base32hex, h) |
|||
require.NoError(t, err) |
|||
return s |
|||
} |
|||
|
|||
func extractCertHashes(addr ma.Multiaddr) []string { |
|||
var certHashesStr []string |
|||
ma.ForEach(addr, func(c ma.Component) bool { |
|||
if c.Protocol().Code == ma.P_CERTHASH { |
|||
certHashesStr = append(certHashesStr, c.Value()) |
|||
} |
|||
return true |
|||
}) |
|||
return certHashesStr |
|||
} |
|||
|
|||
func stripCertHashes(addr ma.Multiaddr) ma.Multiaddr { |
|||
for { |
|||
_, err := addr.ValueForProtocol(ma.P_CERTHASH) |
|||
if err != nil { |
|||
return addr |
|||
} |
|||
addr, _ = ma.SplitLast(addr) |
|||
} |
|||
} |
|||
|
|||
// create a /certhash multiaddr component using the SHA256 of foobar
|
|||
func getCerthashComponent(t *testing.T, b []byte) ma.Multiaddr { |
|||
t.Helper() |
|||
h := sha256.Sum256(b) |
|||
mh, err := multihash.Encode(h[:], multihash.SHA2_256) |
|||
require.NoError(t, err) |
|||
certStr, err := multibase.Encode(multibase.Base58BTC, mh) |
|||
require.NoError(t, err) |
|||
ha, err := ma.NewComponent(ma.ProtocolWithCode(ma.P_CERTHASH).Name, certStr) |
|||
require.NoError(t, err) |
|||
return ha |
|||
} |
|||
|
|||
func TestTransport(t *testing.T) { |
|||
serverID, serverKey := newIdentity(t) |
|||
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer tr.(io.Closer).Close() |
|||
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) |
|||
require.NoError(t, err) |
|||
defer ln.Close() |
|||
|
|||
addrChan := make(chan ma.Multiaddr) |
|||
go func() { |
|||
_, clientKey := newIdentity(t) |
|||
tr2, err := libp2pwebtransport.New(clientKey, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer tr2.(io.Closer).Close() |
|||
|
|||
conn, err := tr2.Dial(context.Background(), ln.Multiaddr(), serverID) |
|||
require.NoError(t, err) |
|||
str, err := conn.OpenStream(context.Background()) |
|||
require.NoError(t, err) |
|||
_, err = str.Write([]byte("foobar")) |
|||
require.NoError(t, err) |
|||
require.NoError(t, str.Close()) |
|||
|
|||
// check RemoteMultiaddr
|
|||
_, addr, err := manet.DialArgs(ln.Multiaddr()) |
|||
require.NoError(t, err) |
|||
_, port, err := net.SplitHostPort(addr) |
|||
require.NoError(t, err) |
|||
require.Equal(t, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%s/quic/webtransport", port)), conn.RemoteMultiaddr()) |
|||
addrChan <- conn.RemoteMultiaddr() |
|||
}() |
|||
|
|||
conn, err := ln.Accept() |
|||
require.NoError(t, err) |
|||
require.False(t, conn.IsClosed()) |
|||
str, err := conn.AcceptStream() |
|||
require.NoError(t, err) |
|||
data, err := io.ReadAll(str) |
|||
require.NoError(t, err) |
|||
require.Equal(t, "foobar", string(data)) |
|||
require.Equal(t, <-addrChan, conn.LocalMultiaddr()) |
|||
require.NoError(t, conn.Close()) |
|||
require.True(t, conn.IsClosed()) |
|||
} |
|||
|
|||
func TestHashVerification(t *testing.T) { |
|||
serverID, serverKey := newIdentity(t) |
|||
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer tr.(io.Closer).Close() |
|||
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) |
|||
require.NoError(t, err) |
|||
done := make(chan struct{}) |
|||
go func() { |
|||
defer close(done) |
|||
_, err := ln.Accept() |
|||
require.Error(t, err) |
|||
}() |
|||
|
|||
_, clientKey := newIdentity(t) |
|||
tr2, err := libp2pwebtransport.New(clientKey, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer tr2.(io.Closer).Close() |
|||
|
|||
foobarHash := getCerthashComponent(t, []byte("foobar")) |
|||
|
|||
t.Run("fails using only a wrong hash", func(t *testing.T) { |
|||
// replace the certificate hash in the multiaddr with a fake hash
|
|||
addr := stripCertHashes(ln.Multiaddr()).Encapsulate(foobarHash) |
|||
_, err := tr2.Dial(context.Background(), addr, serverID) |
|||
require.Error(t, err) |
|||
require.Contains(t, err.Error(), "CRYPTO_ERROR (0x12a): cert hash not found") |
|||
}) |
|||
|
|||
t.Run("fails when adding a wrong hash", func(t *testing.T) { |
|||
_, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID) |
|||
require.Error(t, err) |
|||
}) |
|||
|
|||
require.NoError(t, ln.Close()) |
|||
<-done |
|||
} |
|||
|
|||
func TestCanDial(t *testing.T) { |
|||
valid := []ma.Multiaddr{ |
|||
ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/" + randomMultihash(t)), |
|||
ma.StringCast("/ip6/b16b:8255:efc6:9cd5:1a54:ee86:2d7a:c2e6/udp/1234/quic/webtransport/certhash/" + randomMultihash(t)), |
|||
ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/%s/certhash/%s/certhash/%s", randomMultihash(t), randomMultihash(t), randomMultihash(t))), |
|||
ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport"), // no certificate hash
|
|||
} |
|||
|
|||
invalid := []ma.Multiaddr{ |
|||
ma.StringCast("/ip4/127.0.0.1/udp/1234"), // missing webtransport
|
|||
ma.StringCast("/ip4/127.0.0.1/udp/1234/webtransport"), // missing quic
|
|||
ma.StringCast("/ip4/127.0.0.1/tcp/1234/webtransport"), // WebTransport over TCP? Is this a joke?
|
|||
} |
|||
|
|||
_, key := newIdentity(t) |
|||
tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer tr.(io.Closer).Close() |
|||
|
|||
for _, addr := range valid { |
|||
require.Truef(t, tr.CanDial(addr), "expected to be able to dial %s", addr) |
|||
} |
|||
for _, addr := range invalid { |
|||
require.Falsef(t, tr.CanDial(addr), "expected to not be able to dial %s", addr) |
|||
} |
|||
} |
|||
|
|||
func TestListenAddrValidity(t *testing.T) { |
|||
valid := []ma.Multiaddr{ |
|||
ma.StringCast("/ip6/::/udp/0/quic/webtransport/"), |
|||
ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/"), |
|||
} |
|||
|
|||
invalid := []ma.Multiaddr{ |
|||
ma.StringCast("/ip4/127.0.0.1/udp/1234"), // missing webtransport
|
|||
ma.StringCast("/ip4/127.0.0.1/udp/1234/webtransport"), // missing quic
|
|||
ma.StringCast("/ip4/127.0.0.1/tcp/1234/webtransport"), // WebTransport over TCP? Is this a joke?
|
|||
ma.StringCast("/ip4/127.0.0.1/udp/1234/quic/webtransport/certhash/" + randomMultihash(t)), |
|||
} |
|||
|
|||
_, key := newIdentity(t) |
|||
tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer tr.(io.Closer).Close() |
|||
|
|||
for _, addr := range valid { |
|||
ln, err := tr.Listen(addr) |
|||
require.NoErrorf(t, err, "expected to be able to listen on %s", addr) |
|||
ln.Close() |
|||
} |
|||
for _, addr := range invalid { |
|||
_, err := tr.Listen(addr) |
|||
require.Errorf(t, err, "expected to not be able to listen on %s", addr) |
|||
} |
|||
} |
|||
|
|||
func TestListenerAddrs(t *testing.T) { |
|||
_, key := newIdentity(t) |
|||
tr, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer tr.(io.Closer).Close() |
|||
|
|||
ln1, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) |
|||
require.NoError(t, err) |
|||
ln2, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) |
|||
require.NoError(t, err) |
|||
hashes1 := extractCertHashes(ln1.Multiaddr()) |
|||
require.Len(t, hashes1, 2) |
|||
hashes2 := extractCertHashes(ln2.Multiaddr()) |
|||
require.Equal(t, hashes1, hashes2) |
|||
} |
|||
|
|||
func TestResourceManagerDialing(t *testing.T) { |
|||
ctrl := gomock.NewController(t) |
|||
defer ctrl.Finish() |
|||
rcmgr := mocknetwork.NewMockResourceManager(ctrl) |
|||
|
|||
addr := ma.StringCast("/ip4/9.8.7.6/udp/1234/quic/webtransport") |
|||
p := peer.ID("foobar") |
|||
|
|||
_, key := newIdentity(t) |
|||
tr, err := libp2pwebtransport.New(key, nil, rcmgr) |
|||
require.NoError(t, err) |
|||
defer tr.(io.Closer).Close() |
|||
|
|||
scope := mocknetwork.NewMockConnManagementScope(ctrl) |
|||
rcmgr.EXPECT().OpenConnection(network.DirOutbound, false, addr).Return(scope, nil) |
|||
scope.EXPECT().SetPeer(p).Return(errors.New("denied")) |
|||
scope.EXPECT().Done() |
|||
|
|||
_, err = tr.Dial(context.Background(), addr, p) |
|||
require.EqualError(t, err, "denied") |
|||
} |
|||
|
|||
func TestResourceManagerListening(t *testing.T) { |
|||
clientID, key := newIdentity(t) |
|||
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer cl.(io.Closer).Close() |
|||
|
|||
t.Run("blocking the connection", func(t *testing.T) { |
|||
serverID, key := newIdentity(t) |
|||
ctrl := gomock.NewController(t) |
|||
defer ctrl.Finish() |
|||
rcmgr := mocknetwork.NewMockResourceManager(ctrl) |
|||
tr, err := libp2pwebtransport.New(key, nil, rcmgr) |
|||
require.NoError(t, err) |
|||
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) |
|||
require.NoError(t, err) |
|||
defer ln.Close() |
|||
|
|||
rcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).DoAndReturn(func(_ network.Direction, _ bool, addr ma.Multiaddr) (network.ConnManagementScope, error) { |
|||
_, err := addr.ValueForProtocol(ma.P_WEBTRANSPORT) |
|||
require.NoError(t, err, "expected a WebTransport multiaddr") |
|||
_, addrStr, err := manet.DialArgs(addr) |
|||
require.NoError(t, err) |
|||
host, _, err := net.SplitHostPort(addrStr) |
|||
require.NoError(t, err) |
|||
require.Equal(t, "127.0.0.1", host) |
|||
return nil, errors.New("denied") |
|||
}) |
|||
|
|||
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) |
|||
require.EqualError(t, err, "received status 503") |
|||
}) |
|||
|
|||
t.Run("blocking the peer", func(t *testing.T) { |
|||
serverID, key := newIdentity(t) |
|||
ctrl := gomock.NewController(t) |
|||
defer ctrl.Finish() |
|||
rcmgr := mocknetwork.NewMockResourceManager(ctrl) |
|||
tr, err := libp2pwebtransport.New(key, nil, rcmgr) |
|||
require.NoError(t, err) |
|||
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) |
|||
require.NoError(t, err) |
|||
defer ln.Close() |
|||
|
|||
serverDone := make(chan struct{}) |
|||
scope := mocknetwork.NewMockConnManagementScope(ctrl) |
|||
rcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Any()).Return(scope, nil) |
|||
scope.EXPECT().SetPeer(clientID).Return(errors.New("denied")) |
|||
scope.EXPECT().Done().Do(func() { close(serverDone) }) |
|||
|
|||
// The handshake will complete, but the server will immediately close the connection.
|
|||
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) |
|||
require.NoError(t, err) |
|||
defer conn.Close() |
|||
clientDone := make(chan struct{}) |
|||
go func() { |
|||
defer close(clientDone) |
|||
_, err = conn.AcceptStream() |
|||
require.Error(t, err) |
|||
}() |
|||
select { |
|||
case <-clientDone: |
|||
case <-time.After(5 * time.Second): |
|||
t.Fatal("timeout") |
|||
} |
|||
select { |
|||
case <-serverDone: |
|||
case <-time.After(5 * time.Second): |
|||
t.Fatal("timeout") |
|||
} |
|||
}) |
|||
} |
|||
|
|||
// TODO: unify somehow. We do the same in libp2pquic.
|
|||
//go:generate sh -c "mockgen -package libp2pwebtransport_test -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p/core/connmgr ConnectionGater && goimports -w mock_connection_gater_test.go"
|
|||
|
|||
func TestConnectionGaterDialing(t *testing.T) { |
|||
ctrl := gomock.NewController(t) |
|||
defer ctrl.Finish() |
|||
connGater := NewMockConnectionGater(ctrl) |
|||
|
|||
serverID, serverKey := newIdentity(t) |
|||
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer tr.(io.Closer).Close() |
|||
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) |
|||
require.NoError(t, err) |
|||
defer ln.Close() |
|||
|
|||
connGater.EXPECT().InterceptSecured(network.DirOutbound, serverID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { |
|||
require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr()) |
|||
}) |
|||
_, key := newIdentity(t) |
|||
cl, err := libp2pwebtransport.New(key, connGater, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer cl.(io.Closer).Close() |
|||
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) |
|||
require.EqualError(t, err, "secured connection gated") |
|||
} |
|||
|
|||
func TestConnectionGaterInterceptAccept(t *testing.T) { |
|||
ctrl := gomock.NewController(t) |
|||
defer ctrl.Finish() |
|||
connGater := NewMockConnectionGater(ctrl) |
|||
|
|||
serverID, serverKey := newIdentity(t) |
|||
tr, err := libp2pwebtransport.New(serverKey, connGater, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer tr.(io.Closer).Close() |
|||
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) |
|||
require.NoError(t, err) |
|||
defer ln.Close() |
|||
|
|||
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { |
|||
require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.LocalMultiaddr()) |
|||
require.NotEqual(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr()) |
|||
}) |
|||
|
|||
_, key := newIdentity(t) |
|||
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer cl.(io.Closer).Close() |
|||
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) |
|||
require.EqualError(t, err, "received status 403") |
|||
} |
|||
|
|||
func TestConnectionGaterInterceptSecured(t *testing.T) { |
|||
ctrl := gomock.NewController(t) |
|||
defer ctrl.Finish() |
|||
connGater := NewMockConnectionGater(ctrl) |
|||
|
|||
serverID, serverKey := newIdentity(t) |
|||
tr, err := libp2pwebtransport.New(serverKey, connGater, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer tr.(io.Closer).Close() |
|||
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) |
|||
require.NoError(t, err) |
|||
defer ln.Close() |
|||
|
|||
clientID, key := newIdentity(t) |
|||
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer cl.(io.Closer).Close() |
|||
|
|||
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true) |
|||
connGater.EXPECT().InterceptSecured(network.DirInbound, clientID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { |
|||
require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.LocalMultiaddr()) |
|||
require.NotEqual(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr()) |
|||
}) |
|||
// The handshake will complete, but the server will immediately close the connection.
|
|||
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) |
|||
require.NoError(t, err) |
|||
defer conn.Close() |
|||
done := make(chan struct{}) |
|||
go func() { |
|||
defer close(done) |
|||
_, err = conn.AcceptStream() |
|||
require.Error(t, err) |
|||
}() |
|||
select { |
|||
case <-done: |
|||
case <-time.After(5 * time.Second): |
|||
t.Fatal("timeout") |
|||
} |
|||
} |
|||
|
|||
func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config { |
|||
t.Helper() |
|||
certTempl := &x509.Certificate{ |
|||
SerialNumber: big.NewInt(1234), |
|||
Subject: pkix.Name{Organization: []string{"webtransport"}}, |
|||
NotBefore: start, |
|||
NotAfter: end, |
|||
IsCA: true, |
|||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, |
|||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, |
|||
BasicConstraintsValid: true, |
|||
IPAddresses: []net.IP{ip}, |
|||
} |
|||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) |
|||
require.NoError(t, err) |
|||
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &priv.PublicKey, priv) |
|||
require.NoError(t, err) |
|||
cert, err := x509.ParseCertificate(caBytes) |
|||
require.NoError(t, err) |
|||
return &tls.Config{ |
|||
Certificates: []tls.Certificate{{ |
|||
Certificate: [][]byte{cert.Raw}, |
|||
PrivateKey: priv, |
|||
Leaf: cert, |
|||
}}, |
|||
} |
|||
} |
|||
|
|||
func TestStaticTLSConf(t *testing.T) { |
|||
tlsConf := getTLSConf(t, net.ParseIP("127.0.0.1"), time.Now(), time.Now().Add(365*24*time.Hour)) |
|||
|
|||
serverID, serverKey := newIdentity(t) |
|||
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager, libp2pwebtransport.WithTLSConfig(tlsConf)) |
|||
require.NoError(t, err) |
|||
defer tr.(io.Closer).Close() |
|||
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) |
|||
require.NoError(t, err) |
|||
defer ln.Close() |
|||
require.Empty(t, extractCertHashes(ln.Multiaddr()), "listener address shouldn't contain any certhash") |
|||
|
|||
t.Run("fails when the certificate is invalid", func(t *testing.T) { |
|||
_, key := newIdentity(t) |
|||
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer cl.(io.Closer).Close() |
|||
|
|||
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) |
|||
require.Error(t, err) |
|||
if !strings.Contains(err.Error(), "certificate is not trusted") && |
|||
!strings.Contains(err.Error(), "certificate signed by unknown authority") { |
|||
t.Fatalf("expected a certificate error, got %+v", err) |
|||
} |
|||
}) |
|||
|
|||
t.Run("fails when dialing with a wrong certhash", func(t *testing.T) { |
|||
_, key := newIdentity(t) |
|||
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer cl.(io.Closer).Close() |
|||
|
|||
addr := ln.Multiaddr().Encapsulate(getCerthashComponent(t, []byte("foo"))) |
|||
_, err = cl.Dial(context.Background(), addr, serverID) |
|||
require.Error(t, err) |
|||
require.Contains(t, err.Error(), "cert hash not found") |
|||
}) |
|||
|
|||
t.Run("accepts a valid TLS certificate", func(t *testing.T) { |
|||
_, key := newIdentity(t) |
|||
store := x509.NewCertPool() |
|||
store.AddCert(tlsConf.Certificates[0].Leaf) |
|||
tlsConf := &tls.Config{RootCAs: store} |
|||
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager, libp2pwebtransport.WithTLSClientConfig(tlsConf)) |
|||
require.NoError(t, err) |
|||
defer cl.(io.Closer).Close() |
|||
|
|||
require.True(t, cl.CanDial(ln.Multiaddr())) |
|||
conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) |
|||
require.NoError(t, err) |
|||
defer conn.Close() |
|||
}) |
|||
} |
|||
|
|||
func TestAcceptQueueFilledUp(t *testing.T) { |
|||
serverID, serverKey := newIdentity(t) |
|||
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer tr.(io.Closer).Close() |
|||
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport")) |
|||
require.NoError(t, err) |
|||
defer ln.Close() |
|||
|
|||
newConn := func() (tpt.CapableConn, error) { |
|||
t.Helper() |
|||
_, key := newIdentity(t) |
|||
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager) |
|||
require.NoError(t, err) |
|||
defer cl.(io.Closer).Close() |
|||
return cl.Dial(context.Background(), ln.Multiaddr(), serverID) |
|||
} |
|||
|
|||
for i := 0; i < 16; i++ { |
|||
conn, err := newConn() |
|||
require.NoError(t, err) |
|||
defer conn.Close() |
|||
} |
|||
|
|||
conn, err := newConn() |
|||
if err == nil { |
|||
_, err = conn.AcceptStream() |
|||
} |
|||
require.Error(t, err) |
|||
} |
Loading…
Reference in new issue