Browse Source

use the resource manager when dialing

pull/1737/head
Marten Seemann 2 years ago
parent
commit
ed5a2f58da
  1. 26
      p2p/transport/webtransport/transport.go
  2. 45
      p2p/transport/webtransport/transport_test.go

26
p2p/transport/webtransport/transport.go

@ -11,6 +11,7 @@ import (
pb "github.com/marten-seemann/go-libp2p-webtransport/pb"
ic "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
tpt "github.com/libp2p/go-libp2p-core/transport"
@ -37,6 +38,8 @@ type transport struct {
dialer webtransport.Dialer
rcmgr network.ResourceManager
listenOnce sync.Once
listenOnceErr error
certManager *certManager
@ -47,7 +50,7 @@ type transport struct {
var _ tpt.Transport = &transport{}
var _ io.Closer = &transport{}
func New(key ic.PrivKey) (tpt.Transport, error) {
func New(key ic.PrivKey, rcmgr network.ResourceManager) (tpt.Transport, error) {
id, err := peer.IDFromPrivateKey(key)
if err != nil {
return nil, err
@ -55,6 +58,7 @@ func New(key ic.PrivKey) (tpt.Transport, error) {
t := &transport{
pid: id,
privKey: key,
rcmgr: rcmgr,
dialer: webtransport.Dialer{
RoundTripper: &http3.RoundTripper{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // TODO: verify certificate,
@ -70,6 +74,26 @@ func New(key ic.PrivKey) (tpt.Transport, error) {
}
func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) {
scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr)
if err != nil {
log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err)
return nil, err
}
if err := scope.SetPeer(p); err != nil {
log.Debugw("resource manager blocked outgoing connection for peer", "peer", p, "addr", raddr, "error", err)
scope.Done()
return nil, err
}
conn, err := t.dial(ctx, raddr, p)
if err != nil {
scope.Done()
return nil, err
}
return conn, nil
}
func (t *transport) dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) {
_, addr, err := manet.DialArgs(raddr)
if err != nil {
return nil, err

45
p2p/transport/webtransport/transport_test.go

@ -4,16 +4,20 @@ import (
"context"
"crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"io"
"strings"
"net"
"testing"
libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport"
ic "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/golang/mock/gomock"
mocknetwork "github.com/libp2p/go-libp2p-testing/mocks/network"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multibase"
@ -51,7 +55,7 @@ func extractCertHashes(t *testing.T, addr ma.Multiaddr) []string {
func TestTransport(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey)
tr, err := libp2pwebtransport.New(serverKey, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
@ -61,7 +65,7 @@ func TestTransport(t *testing.T) {
addrChan := make(chan ma.Multiaddr)
go func() {
_, clientKey := newIdentity(t)
tr2, err := libp2pwebtransport.New(clientKey)
tr2, err := libp2pwebtransport.New(clientKey, network.NullResourceManager)
require.NoError(t, err)
defer tr2.(io.Closer).Close()
@ -76,7 +80,8 @@ func TestTransport(t *testing.T) {
// check RemoteMultiaddr
_, addr, err := manet.DialArgs(ln.Multiaddr())
require.NoError(t, err)
port := strings.Split(addr, ":")[1]
_, port, err := net.SplitHostPort(addr)
require.NoError(t, err)
require.Equal(t, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%s/quic/webtransport", port)), conn.RemoteMultiaddr())
addrChan <- conn.RemoteMultiaddr()
}()
@ -93,7 +98,7 @@ func TestTransport(t *testing.T) {
func TestHashVerification(t *testing.T) {
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey)
tr, err := libp2pwebtransport.New(serverKey, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
@ -106,7 +111,7 @@ func TestHashVerification(t *testing.T) {
}()
_, clientKey := newIdentity(t)
tr2, err := libp2pwebtransport.New(clientKey)
tr2, err := libp2pwebtransport.New(clientKey, network.NullResourceManager)
require.NoError(t, err)
defer tr2.(io.Closer).Close()
@ -152,7 +157,7 @@ func TestCanDial(t *testing.T) {
}
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key)
tr, err := libp2pwebtransport.New(key, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
@ -178,7 +183,7 @@ func TestListenAddrValidity(t *testing.T) {
}
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key)
tr, err := libp2pwebtransport.New(key, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
@ -195,7 +200,7 @@ func TestListenAddrValidity(t *testing.T) {
func TestListenerAddrs(t *testing.T) {
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key)
tr, err := libp2pwebtransport.New(key, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
@ -208,3 +213,25 @@ func TestListenerAddrs(t *testing.T) {
hashes2 := extractCertHashes(t, ln2.Multiaddr())
require.Equal(t, hashes1, hashes2)
}
func TestResourceManagerDialing(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
rcmgr := mocknetwork.NewMockResourceManager(ctrl)
addr := ma.StringCast("/ip4/9.8.7.6/udp/1234/quic/webtransport")
p := peer.ID("foobar")
_, key := newIdentity(t)
tr, err := libp2pwebtransport.New(key, rcmgr)
require.NoError(t, err)
defer tr.(io.Closer).Close()
scope := mocknetwork.NewMockConnManagementScope(ctrl)
rcmgr.EXPECT().OpenConnection(network.DirOutbound, false, addr).Return(scope, nil)
scope.EXPECT().SetPeer(p).Return(errors.New("denied"))
scope.EXPECT().Done()
_, err = tr.Dial(context.Background(), addr, p)
require.EqualError(t, err, "denied")
}

Loading…
Cancel
Save