Browse Source

implement InterceptAccept connection gating

pull/1737/head
Marten Seemann 2 years ago
parent
commit
a508e94004
  1. 14
      p2p/transport/webtransport/listener.go
  2. 17
      p2p/transport/webtransport/transport.go
  3. 46
      p2p/transport/webtransport/transport_test.go

14
p2p/transport/webtransport/listener.go

@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/libp2p/go-libp2p-core/connmgr"
"github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/network"
tpt "github.com/libp2p/go-libp2p-core/transport" tpt "github.com/libp2p/go-libp2p-core/transport"
@ -30,6 +31,7 @@ type listener struct {
noise *noise.Transport noise *noise.Transport
certManager *certManager certManager *certManager
rcmgr network.ResourceManager rcmgr network.ResourceManager
gater connmgr.ConnectionGater
server webtransport.Server server webtransport.Server
@ -46,7 +48,7 @@ type listener struct {
var _ tpt.Listener = &listener{} var _ tpt.Listener = &listener{}
func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, rcmgr network.ResourceManager) (tpt.Listener, error) { func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Transport, certManager *certManager, gater connmgr.ConnectionGater, rcmgr network.ResourceManager) (tpt.Listener, error) {
network, addr, err := manet.DialArgs(laddr) network, addr, err := manet.DialArgs(laddr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -68,6 +70,7 @@ func newListener(laddr ma.Multiaddr, transport tpt.Transport, noise *noise.Trans
noise: noise, noise: noise,
certManager: certManager, certManager: certManager,
rcmgr: rcmgr, rcmgr: rcmgr,
gater: gater,
queue: make(chan tpt.CapableConn, queueLen), queue: make(chan tpt.CapableConn, queueLen),
serverClosed: make(chan struct{}), serverClosed: make(chan struct{}),
addr: udpConn.LocalAddr(), addr: udpConn.LocalAddr(),
@ -106,6 +109,10 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
return return
} }
if l.gater != nil && !l.gater.InterceptAccept(&connMultiaddrs{local: l.multiaddr, remote: remoteMultiaddr}) {
w.WriteHeader(http.StatusForbidden)
return
}
connScope, err := l.rcmgr.OpenConnection(network.DirInbound, false, remoteMultiaddr) connScope, err := l.rcmgr.OpenConnection(network.DirInbound, false, remoteMultiaddr)
if err != nil { if err != nil {
@ -174,9 +181,8 @@ func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (c
} }
return &connSecurityMultiaddrsImpl{ return &connSecurityMultiaddrsImpl{
ConnSecurity: c, ConnSecurity: c,
local: local, ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote},
remote: remote,
}, nil }, nil
} }

17
p2p/transport/webtransport/transport.go

@ -39,13 +39,17 @@ type connSecurityMultiaddrs interface {
type connSecurityMultiaddrsImpl struct { type connSecurityMultiaddrsImpl struct {
network.ConnSecurity network.ConnSecurity
network.ConnMultiaddrs
}
type connMultiaddrs struct {
local, remote ma.Multiaddr local, remote ma.Multiaddr
} }
var _ connSecurityMultiaddrs = &connSecurityMultiaddrsImpl{} var _ network.ConnMultiaddrs = &connMultiaddrs{}
func (c *connSecurityMultiaddrsImpl) LocalMultiaddr() ma.Multiaddr { return c.local } func (c *connMultiaddrs) LocalMultiaddr() ma.Multiaddr { return c.local }
func (c *connSecurityMultiaddrsImpl) RemoteMultiaddr() ma.Multiaddr { return c.remote } func (c *connMultiaddrs) RemoteMultiaddr() ma.Multiaddr { return c.remote }
type transport struct { type transport struct {
privKey ic.PrivKey privKey ic.PrivKey
@ -176,9 +180,8 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p
return nil, err return nil, err
} }
return &connSecurityMultiaddrsImpl{ return &connSecurityMultiaddrsImpl{
ConnSecurity: c, ConnSecurity: c,
local: local, ConnMultiaddrs: &connMultiaddrs{local: local, remote: remote},
remote: remote,
}, nil }, nil
} }
@ -225,7 +228,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) {
if t.listenOnceErr != nil { if t.listenOnceErr != nil {
return nil, t.listenOnceErr return nil, t.listenOnceErr
} }
return newListener(laddr, t, t.noise, t.certManager, t.rcmgr) return newListener(laddr, t, t.noise, t.certManager, t.gater, t.rcmgr)
} }
func (t *transport) Protocols() []int { func (t *transport) Protocols() []int {

46
p2p/transport/webtransport/transport_test.go

@ -54,6 +54,16 @@ func extractCertHashes(addr ma.Multiaddr) []string {
return certHashesStr return certHashesStr
} }
func stripCertHashes(addr ma.Multiaddr) ma.Multiaddr {
for {
_, err := addr.ValueForProtocol(ma.P_CERTHASH)
if err != nil {
return addr
}
addr, _ = ma.SplitLast(addr)
}
}
func TestTransport(t *testing.T) { func TestTransport(t *testing.T) {
serverID, serverKey := newIdentity(t) serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager) tr, err := libp2pwebtransport.New(serverKey, nil, network.NullResourceManager)
@ -323,15 +333,7 @@ func TestConnectionGaterDialing(t *testing.T) {
defer ln.Close() defer ln.Close()
connGater.EXPECT().InterceptSecured(network.DirOutbound, serverID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { connGater.EXPECT().InterceptSecured(network.DirOutbound, serverID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
expected := ln.Multiaddr() require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr())
for {
_, err := expected.ValueForProtocol(ma.P_CERTHASH)
if err != nil {
break
}
expected, _ = ma.SplitLast(expected)
}
require.Equal(t, expected, addrs.RemoteMultiaddr())
}) })
_, key := newIdentity(t) _, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, connGater, network.NullResourceManager) cl, err := libp2pwebtransport.New(key, connGater, network.NullResourceManager)
@ -340,3 +342,29 @@ func TestConnectionGaterDialing(t *testing.T) {
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.EqualError(t, err, "secured connection gated") require.EqualError(t, err, "secured connection gated")
} }
func TestConnectionGaterInterceptAccept(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
connGater := NewMockConnectionGater(ctrl)
serverID, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, connGater, network.NullResourceManager)
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)
defer ln.Close()
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.LocalMultiaddr())
require.NotEqual(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr())
})
_, key := newIdentity(t)
cl, err := libp2pwebtransport.New(key, nil, network.NullResourceManager)
require.NoError(t, err)
defer cl.(io.Closer).Close()
_, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID)
require.EqualError(t, err, "received status 403")
}

Loading…
Cancel
Save