Browse Source

refactor the conn constructor

pull/1737/head
Marten Seemann 2 years ago
parent
commit
ecc1eff49d
  1. 64
      p2p/transport/webtransport/conn.go
  2. 32
      p2p/transport/webtransport/listener.go
  3. 47
      p2p/transport/webtransport/transport.go

64
p2p/transport/webtransport/conn.go

@ -2,60 +2,32 @@ package libp2pwebtransport
import (
"context"
"fmt"
ic "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
tpt "github.com/libp2p/go-libp2p-core/transport"
"github.com/marten-seemann/webtransport-go"
ma "github.com/multiformats/go-multiaddr"
)
type conn struct {
connSecurityMultiaddrs
transport tpt.Transport
session *webtransport.Session
localPeer, remotePeer peer.ID
local, remote ma.Multiaddr
privKey ic.PrivKey
remotePubKey ic.PubKey
scope network.ConnScope
scope network.ConnScope
}
func newConn(tr tpt.Transport, sess *webtransport.Session, privKey ic.PrivKey, remotePubKey ic.PubKey, scope network.ConnScope) (*conn, error) {
localPeer, err := peer.IDFromPrivateKey(privKey)
if err != nil {
return nil, err
}
remotePeer, err := peer.IDFromPublicKey(remotePubKey)
if err != nil {
return nil, err
}
local, err := toWebtransportMultiaddr(sess.LocalAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting local addr: %w", err)
}
remote, err := toWebtransportMultiaddr(sess.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting remote addr: %w", err)
}
var _ tpt.CapableConn = &conn{}
func newConn(tr tpt.Transport, sess *webtransport.Session, sconn connSecurityMultiaddrs, scope network.ConnScope) *conn {
return &conn{
transport: tr,
session: sess,
privKey: privKey,
localPeer: localPeer,
remotePeer: remotePeer,
remotePubKey: remotePubKey,
local: local,
remote: remote,
scope: scope,
}, nil
connSecurityMultiaddrs: sconn,
transport: tr,
session: sess,
scope: scope,
}
}
var _ tpt.CapableConn = &conn{}
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
str, err := c.session.OpenStreamSync(ctx)
return &stream{str}, err
@ -66,13 +38,7 @@ func (c *conn) AcceptStream() (network.MuxedStream, error) {
return &stream{str}, err
}
func (c *conn) Close() error { return c.session.Close() }
func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil }
func (c *conn) LocalPeer() peer.ID { return c.localPeer }
func (c *conn) LocalPrivateKey() ic.PrivKey { return c.privKey }
func (c *conn) RemotePeer() peer.ID { return c.remotePeer }
func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey }
func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.local }
func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remote }
func (c *conn) Scope() network.ConnScope { return c.scope }
func (c *conn) Transport() tpt.Transport { return c.transport }
func (c *conn) Close() error { return c.session.Close() }
func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil }
func (c *conn) Scope() network.ConnScope { return c.scope }
func (c *conn) Transport() tpt.Transport { return c.transport }

32
p2p/transport/webtransport/listener.go

@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"time"
@ -140,15 +141,8 @@ func (l *listener) httpHandler(w http.ResponseWriter, r *http.Request) {
return
}
c, err := newConn(l.transport, sess, sconn.LocalPrivateKey(), sconn.RemotePublicKey(), connScope)
if err != nil {
sess.Close()
connScope.Done()
return
}
// TODO: think about what happens when this channel fills up
l.queue <- c
l.queue <- newConn(l.transport, sess, sconn, connScope)
}
func (l *listener) Accept() (tpt.CapableConn, error) {
@ -160,12 +154,30 @@ func (l *listener) Accept() (tpt.CapableConn, error) {
}
}
func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (network.ConnSecurity, error) {
func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (connSecurityMultiaddrs, error) {
local, err := toWebtransportMultiaddr(sess.LocalAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting local addr: %w", err)
}
remote, err := toWebtransportMultiaddr(sess.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting remote addr: %w", err)
}
str, err := sess.AcceptStream(ctx)
if err != nil {
return nil, err
}
return l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "")
c, err := l.noise.SecureInbound(ctx, &webtransportStream{Stream: str, wsess: sess}, "")
if err != nil {
return nil, err
}
return &connSecurityMultiaddrsImpl{
ConnSecurity: c,
local: local,
remote: remote,
}, nil
}
func (l *listener) Addr() net.Addr {

47
p2p/transport/webtransport/transport.go

@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"fmt"
manet "github.com/multiformats/go-multiaddr/net"
"io"
"sync"
"time"
@ -22,6 +21,7 @@ import (
"github.com/lucas-clemente/quic-go/http3"
"github.com/marten-seemann/webtransport-go"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multihash"
)
@ -31,6 +31,21 @@ const webtransportHTTPEndpoint = "/.well-known/libp2p-webtransport"
const certValidity = 14 * 24 * time.Hour
type connSecurityMultiaddrs interface {
network.ConnMultiaddrs
network.ConnSecurity
}
type connSecurityMultiaddrsImpl struct {
network.ConnSecurity
local, remote ma.Multiaddr
}
var _ connSecurityMultiaddrs = &connSecurityMultiaddrsImpl{}
func (c *connSecurityMultiaddrsImpl) LocalMultiaddr() ma.Multiaddr { return c.local }
func (c *connSecurityMultiaddrsImpl) RemoteMultiaddr() ma.Multiaddr { return c.remote }
type transport struct {
privKey ic.PrivKey
pid peer.ID
@ -104,13 +119,8 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
scope.Done()
return nil, err
}
c, err := newConn(t, sess, t.privKey, sconn.RemotePublicKey(), scope)
if err != nil {
sess.Close()
scope.Done()
return nil, err
}
return c, nil
return newConn(t, sess, sconn, scope), nil
}
func (t *transport) dial(ctx context.Context, addr string) (*webtransport.Session, error) {
@ -125,7 +135,16 @@ func (t *transport) dial(ctx context.Context, addr string) (*webtransport.Sessio
return sess, err
}
func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (network.ConnSecurity, error) {
func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (connSecurityMultiaddrs, error) {
local, err := toWebtransportMultiaddr(sess.LocalAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting local addr: %w", err)
}
remote, err := toWebtransportMultiaddr(sess.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("error determiniting remote addr: %w", err)
}
str, err := sess.OpenStreamSync(ctx)
if err != nil {
return nil, err
@ -144,7 +163,15 @@ func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p p
if err != nil {
return nil, fmt.Errorf("failed to marshal WebTransport protobuf: %w", err)
}
return t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wsess: sess}, p, msgBytes)
c, err := t.noise.SecureOutboundWithEarlyData(ctx, &webtransportStream{Stream: str, wsess: sess}, p, msgBytes)
if err != nil {
return nil, err
}
return &connSecurityMultiaddrsImpl{
ConnSecurity: c,
local: local,
remote: remote,
}, nil
}
func (t *transport) checkEarlyData(b []byte) error {

Loading…
Cancel
Save