From c7149b3be4b74a160c52271ccbda241d19e6219d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 9 Jul 2022 18:49:40 +0000 Subject: [PATCH] pass the connection scope to the connection (#11) --- p2p/transport/webtransport/conn.go | 15 +++---- p2p/transport/webtransport/listener.go | 29 +++++++------ p2p/transport/webtransport/transport.go | 45 +++++++++++++------- p2p/transport/webtransport/transport_test.go | 15 ++++++- 4 files changed, 63 insertions(+), 41 deletions(-) diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index 90ae154b8..3478d39db 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -21,9 +21,10 @@ type conn struct { local, remote ma.Multiaddr privKey ic.PrivKey remotePubKey ic.PubKey + scope network.ConnScope } -func newConn(tr tpt.Transport, wsess *webtransport.Session, privKey ic.PrivKey, remotePubKey ic.PubKey) (*conn, error) { +func newConn(tr tpt.Transport, wsess *webtransport.Session, privKey ic.PrivKey, remotePubKey ic.PubKey, scope network.ConnScope) (*conn, error) { localPeer, err := peer.IDFromPrivateKey(privKey) if err != nil { return nil, err @@ -49,6 +50,7 @@ func newConn(tr tpt.Transport, wsess *webtransport.Session, privKey ic.PrivKey, remotePubKey: remotePubKey, local: local, remote: remote, + scope: scope, }, nil } @@ -78,12 +80,5 @@ func (c *conn) RemotePeer() peer.ID { return c.remotePeer } func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.local } func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remote } - -func (c *conn) Scope() network.ConnScope { - // TODO implement me - panic("implement me") -} - -func (c *conn) Transport() tpt.Transport { - return c.transport -} +func (c *conn) Scope() network.ConnScope { return c.scope } +func (c *conn) Transport() tpt.Transport { return c.transport } diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index ef21c23ff..739c230ca 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -114,7 +114,7 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { } // TODO: check ?type=multistream URL param - c, err := l.server.Upgrade(w, r) + sess, err := l.server.Upgrade(w, r) if err != nil { log.Debugw("upgrade failed", "error", err) // TODO: think about the status code to use here @@ -123,25 +123,32 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) { return } ctx, cancel := context.WithTimeout(l.ctx, handshakeTimeout) - conn, err := l.handshake(ctx, c) + sconn, err := l.handshake(ctx, sess) if err != nil { cancel() log.Debugw("handshake failed", "error", err) - c.Close() + sess.Close() connScope.Done() return } cancel() - if err := connScope.SetPeer(conn.RemotePeer()); err != nil { - log.Debugw("resource manager blocked incoming connection for peer", "peer", conn.RemotePeer(), "addr", r.RemoteAddr, "error", err) - conn.Close() + if err := connScope.SetPeer(sconn.RemotePeer()); err != nil { + log.Debugw("resource manager blocked incoming connection for peer", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err) + sess.Close() + connScope.Done() + return + } + + c, err := newConn(l.transport, sess, sconn.LocalPrivateKey(), sconn.RemotePublicKey(), connScope) + if err != nil { + sess.Close() connScope.Done() return } // TODO: think about what happens when this channel fills up - l.queue <- conn + l.queue <- c } func (l *listener) Accept() (tpt.CapableConn, error) { @@ -153,16 +160,12 @@ func (l *listener) Accept() (tpt.CapableConn, error) { } } -func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (tpt.CapableConn, error) { +func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (network.ConnSecurity, error) { str, err := sess.AcceptStream(ctx) if err != nil { return nil, err } - conn, err := l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") - if err != nil { - return nil, err - } - return newConn(l.transport, sess, conn.LocalPrivateKey(), conn.RemotePublicKey()) + return l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "") } func (l *listener) Addr() net.Addr { diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 2d2ffcfb6..640221a9b 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "fmt" + manet "github.com/multiformats/go-multiaddr/net" "io" "sync" "time" @@ -21,7 +22,6 @@ import ( "github.com/lucas-clemente/quic-go/http3" "github.com/marten-seemann/webtransport-go" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multihash" ) @@ -73,6 +73,15 @@ func New(key ic.PrivKey, rcmgr network.ResourceManager) (tpt.Transport, error) { } func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + _, addr, err := manet.DialArgs(raddr) + if err != nil { + return nil, err + } + certHashes, err := extractCertHashes(raddr) + if err != nil { + return nil, err + } + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) if err != nil { log.Debugw("resource manager blocked outgoing connection", "peer", p, "addr", raddr, "error", err) @@ -84,32 +93,40 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return nil, err } - conn, err := t.dial(ctx, raddr, p) + sess, err := t.dial(ctx, addr) if err != nil { scope.Done() return nil, err } - return conn, nil -} - -func (t *transport) dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { - _, addr, err := manet.DialArgs(raddr) + sconn, err := t.upgrade(ctx, sess, p, certHashes) if err != nil { + sess.Close() + scope.Done() return nil, err } - url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint) - certHashes, err := extractCertHashes(raddr) + c, err := newConn(t, sess, t.privKey, sconn.RemotePublicKey(), scope) if err != nil { + sess.Close() + scope.Done() return nil, err } - rsp, wconn, err := t.dialer.Dial(ctx, url, nil) + return c, nil +} + +func (t *transport) dial(ctx context.Context, addr string) (*webtransport.Session, error) { + url := fmt.Sprintf("https://%s%s", addr, webtransportHTTPEndpoint) + rsp, sess, err := t.dialer.Dial(ctx, url, nil) if err != nil { return nil, err } if rsp.StatusCode < 200 || rsp.StatusCode > 299 { return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode) } - str, err := wconn.OpenStreamSync(ctx) + return sess, err +} + +func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (network.ConnSecurity, error) { + str, err := sess.OpenStreamSync(ctx) if err != nil { return nil, err } @@ -127,11 +144,7 @@ func (t *transport) dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp if err != nil { return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err) } - sconn, err := t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wsess: wconn}, p, msgBytes) - if err != nil { - return nil, err - } - return newConn(t, wconn, t.privKey, sconn.RemotePublicKey()) + return t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wsess: sess}, p, msgBytes) } func (t *transport) checkEarlyData(b []byte) error { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 115c7f103..7ddcff9ff 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -9,6 +9,7 @@ import ( "io" "net" "testing" + "time" libp2pwebtransport "github.com/marten-seemann/go-libp2p-webtransport" @@ -287,7 +288,17 @@ func TestResourceManagerListening(t *testing.T) { // The handshake will complete, but the server will immediately close the connection. conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) - _, err = conn.AcceptStream() - require.Error(t, err) + defer conn.Close() + done := make(chan struct{}) + go func() { + defer close(done) + _, err = conn.AcceptStream() + require.Error(t, err) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } }) }