Browse Source

webrtc: setup datachannel handlers before connecting to a peer (#2716)

If done after connecting to a peer, there's a small window of time
when datachannels created by the peer may cause a memory leak.
pull/2717/head
Sukun 9 months ago
committed by GitHub
parent
commit
473a5e9def
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 50
      p2p/transport/webrtc/connection.go
  2. 75
      p2p/transport/webrtc/listener.go
  3. 122
      p2p/transport/webrtc/transport.go

50
p2p/transport/webrtc/connection.go

@ -13,14 +13,10 @@ import (
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb"
"github.com/libp2p/go-msgio"
ma "github.com/multiformats/go-multiaddr"
"github.com/pion/datachannel"
"github.com/pion/webrtc/v3"
"google.golang.org/protobuf/proto"
)
var _ tpt.CapableConn = &connection{}
@ -77,6 +73,7 @@ func newConnection(
remotePeer peer.ID,
remoteKey ic.PubKey,
remoteMultiaddr ma.Multiaddr,
incomingDataChannels chan dataChannel,
) (*connection, error) {
ctx, cancel := context.WithCancel(context.Background())
c := &connection{
@ -94,7 +91,7 @@ func newConnection(
cancel: cancel,
streams: make(map[uint16]*stream),
acceptQueue: make(chan dataChannel, maxAcceptQueueLen),
acceptQueue: incomingDataChannels,
}
switch direction {
case network.DirInbound:
@ -105,24 +102,6 @@ func newConnection(
}
pc.OnConnectionStateChange(c.onConnectionStateChange)
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
dc.OnOpen(func() {
rwc, err := dc.Detach()
if err != nil {
log.Warnf("could not detach datachannel: id: %d", *dc.ID())
return
}
select {
case c.acceptQueue <- dataChannel{rwc, dc}:
default:
log.Warnf("connection busy, rejecting stream")
b, _ := proto.Marshal(&pb.Message{Flag: pb.Message_RESET.Enum()})
w := msgio.NewWriter(rwc)
w.WriteMsg(b)
rwc.Close()
}
})
})
return c, nil
}
@ -274,28 +253,3 @@ func (c *connection) detachChannel(ctx context.Context, dc *webrtc.DataChannel)
return rwc, err
}
}
// A note on these setters and why they are needed:
//
// The connection object sets up receiving datachannels (streams) from the remote peer.
// Please consider the XX noise handshake pattern from a peer A to peer B as described at:
// https://noiseexplorer.com/patterns/XX/
//
// The initiator A completes the noise handshake before B.
// This would allow A to create new datachannels before B has set up the callbacks to process incoming datachannels.
// This would create a situation where A has successfully created a stream but B is not aware of it.
// Moving the construction of the connection object before the noise handshake eliminates this issue,
// as callbacks have been set up for both peers.
//
// This could lead to a case where streams are created during the noise handshake,
// and the handshake fails. In this case, we would close the underlying peerconnection.
// only used during connection setup
func (c *connection) setRemotePeer(id peer.ID) {
c.remotePeer = id
}
// only used during connection setup
func (c *connection) setRemotePublicKey(key ic.PubKey) {
c.remoteKey = key
}

75
p2p/transport/webrtc/listener.go

@ -184,11 +184,11 @@ func (l *listener) setupConnection(
ctx context.Context, scope network.ConnManagementScope,
remoteMultiaddr ma.Multiaddr, candidate udpmux.Candidate,
) (tConn tpt.CapableConn, err error) {
var pc *webrtc.PeerConnection
var w webRTCConnection
defer func() {
if err != nil {
if pc != nil {
_ = pc.Close()
if w.PeerConnection != nil {
_ = w.PeerConnection.Close()
}
if tConn != nil {
_ = tConn.Close()
@ -224,34 +224,24 @@ func (l *listener) setupConnection(
)
settingEngine.DetachDataChannels()
api := webrtc.NewAPI(webrtc.WithSettingEngine(settingEngine))
pc, err = api.NewPeerConnection(l.config)
w, err = newWebRTCConnection(settingEngine, l.config)
if err != nil {
return nil, err
return nil, fmt.Errorf("instantiating peer connection failed: %w", err)
}
negotiated, id := handshakeChannelNegotiated, handshakeChannelID
rawDatachannel, err := pc.CreateDataChannel("", &webrtc.DataChannelInit{
Negotiated: &negotiated,
ID: &id,
})
if err != nil {
return nil, err
}
errC := addOnConnectionStateChangeCallback(pc)
errC := addOnConnectionStateChangeCallback(w.PeerConnection)
// Infer the client SDP from the incoming STUN message by setting the ice-ufrag.
if err := pc.SetRemoteDescription(webrtc.SessionDescription{
if err := w.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{
SDP: createClientSDP(candidate.Addr, candidate.Ufrag),
Type: webrtc.SDPTypeOffer,
}); err != nil {
return nil, err
}
answer, err := pc.CreateAnswer(nil)
answer, err := w.PeerConnection.CreateAnswer(nil)
if err != nil {
return nil, err
}
if err := pc.SetLocalDescription(answer); err != nil {
if err := w.PeerConnection.SetLocalDescription(answer); err != nil {
return nil, err
}
@ -264,34 +254,14 @@ func (l *listener) setupConnection(
}
}
rwc, err := detachHandshakeDataChannel(ctx, rawDatachannel)
if err != nil {
return nil, err
}
localMultiaddrWithoutCerthash, _ := ma.SplitFunc(l.localMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })
handshakeChannel := newStream(rawDatachannel, rwc, func() {})
// The connection is instantiated before performing the Noise handshake. This is
// to handle the case where the remote is faster and attempts to initiate a stream
// before the ondatachannel callback can be set.
conn, err := newConnection(
network.DirInbound,
pc,
l.transport,
scope,
l.transport.localPeerId,
localMultiaddrWithoutCerthash,
"", // remotePeer
nil, // remoteKey
remoteMultiaddr,
)
// Run the noise handshake.
rwc, err := detachHandshakeDataChannel(ctx, w.HandshakeDataChannel)
if err != nil {
return nil, err
}
handshakeChannel := newStream(w.HandshakeDataChannel, rwc, func() {})
// we do not yet know A's peer ID so accept any inbound
remotePubKey, err := l.transport.noiseHandshake(ctx, pc, handshakeChannel, "", crypto.SHA256, true)
remotePubKey, err := l.transport.noiseHandshake(ctx, w.PeerConnection, handshakeChannel, "", crypto.SHA256, true)
if err != nil {
return nil, err
}
@ -299,14 +269,27 @@ func (l *listener) setupConnection(
if err != nil {
return nil, err
}
// earliest point where we know the remote's peerID
if err := scope.SetPeer(remotePeer); err != nil {
return nil, err
}
conn.setRemotePeer(remotePeer)
conn.setRemotePublicKey(remotePubKey)
localMultiaddrWithoutCerthash, _ := ma.SplitFunc(l.localMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })
conn, err := newConnection(
network.DirInbound,
w.PeerConnection,
l.transport,
scope,
l.transport.localPeerId,
localMultiaddrWithoutCerthash,
remotePeer,
remotePubKey,
remoteMultiaddr,
w.IncomingDataChannels,
)
if err != nil {
return nil, err
}
return conn, err
}

122
p2p/transport/webrtc/transport.go

@ -26,6 +26,7 @@ import (
"time"
mrand "golang.org/x/exp/rand"
"google.golang.org/protobuf/proto"
"github.com/libp2p/go-libp2p/core/connmgr"
ic "github.com/libp2p/go-libp2p/core/crypto"
@ -35,6 +36,8 @@ import (
"github.com/libp2p/go-libp2p/core/sec"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/security/noise"
"github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb"
"github.com/libp2p/go-msgio"
logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
@ -259,11 +262,11 @@ func (t *WebRTCTransport) Dial(ctx context.Context, remoteMultiaddr ma.Multiaddr
}
func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagementScope, remoteMultiaddr ma.Multiaddr, p peer.ID) (tConn tpt.CapableConn, err error) {
var pc *webrtc.PeerConnection
var w webRTCConnection
defer func() {
if err != nil {
if pc != nil {
_ = pc.Close()
if w.PeerConnection != nil {
_ = w.PeerConnection.Close()
}
if tConn != nil {
_ = tConn.Close()
@ -319,32 +322,20 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
// it will not connect to anything.
settingEngine.SetIncludeLoopbackCandidate(true)
api := webrtc.NewAPI(webrtc.WithSettingEngine(settingEngine))
pc, err = api.NewPeerConnection(t.webrtcConfig)
w, err = newWebRTCConnection(settingEngine, t.webrtcConfig)
if err != nil {
return nil, fmt.Errorf("instantiate peerconnection: %w", err)
return nil, fmt.Errorf("instantiating peer connection failed: %w", err)
}
errC := addOnConnectionStateChangeCallback(pc)
// We need to set negotiated = true for this channel on both
// the client and server to avoid DCEP errors.
negotiated, id := handshakeChannelNegotiated, handshakeChannelID
rawHandshakeChannel, err := pc.CreateDataChannel("", &webrtc.DataChannelInit{
Negotiated: &negotiated,
ID: &id,
})
if err != nil {
return nil, fmt.Errorf("create datachannel: %w", err)
}
errC := addOnConnectionStateChangeCallback(w.PeerConnection)
// do offer-answer exchange
offer, err := pc.CreateOffer(nil)
offer, err := w.PeerConnection.CreateOffer(nil)
if err != nil {
return nil, fmt.Errorf("create offer: %w", err)
}
err = pc.SetLocalDescription(offer)
err = w.PeerConnection.SetLocalDescription(offer)
if err != nil {
return nil, fmt.Errorf("set local description: %w", err)
}
@ -355,7 +346,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
}
answer := webrtc.SessionDescription{SDP: answerSDPString, Type: webrtc.SDPTypeAnswer}
err = pc.SetRemoteDescription(answer)
err = w.PeerConnection.SetRemoteDescription(answer)
if err != nil {
return nil, fmt.Errorf("set remote description: %w", err)
}
@ -370,55 +361,52 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
return nil, errors.New("peerconnection opening timed out")
}
detached, err := detachHandshakeDataChannel(ctx, rawHandshakeChannel)
// We are connected, run the noise handshake
detached, err := detachHandshakeDataChannel(ctx, w.HandshakeDataChannel)
if err != nil {
return nil, err
}
channel := newStream(w.HandshakeDataChannel, detached, func() {})
remotePubKey, err := t.noiseHandshake(ctx, w.PeerConnection, channel, p, remoteHashFunction, false)
if err != nil {
return nil, err
}
// set the local address from the candidate pair
cp, err := rawHandshakeChannel.Transport().Transport().ICETransport().GetSelectedCandidatePair()
// Setup local and remote address for the connection
cp, err := w.HandshakeDataChannel.Transport().Transport().ICETransport().GetSelectedCandidatePair()
if cp == nil {
return nil, errors.New("ice connection did not have selected candidate pair: nil result")
}
if err != nil {
return nil, fmt.Errorf("ice connection did not have selected candidate pair: error: %w", err)
}
channel := newStream(rawHandshakeChannel, detached, func() {})
// the local address of the selected candidate pair should be the
// local address for the connection, since different datachannels
// are multiplexed over the same SCTP connection
// the local address of the selected candidate pair should be the local address for the connection
localAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Local.Address), Port: int(cp.Local.Port)})
if err != nil {
return nil, err
}
remoteMultiaddrWithoutCerthash, _ := ma.SplitFunc(remoteMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })
// we can only know the remote public key after the noise handshake,
// but need to set up the callbacks on the peerconnection
conn, err := newConnection(
network.DirOutbound,
pc,
w.PeerConnection,
t,
scope,
t.localPeerId,
localAddr,
p,
nil,
remotePubKey,
remoteMultiaddrWithoutCerthash,
w.IncomingDataChannels,
)
if err != nil {
return nil, err
}
remotePubKey, err := t.noiseHandshake(ctx, pc, channel, p, remoteHashFunction, false)
if err != nil {
return nil, err
}
if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, conn) {
return nil, fmt.Errorf("secured connection gated")
}
conn.setRemotePublicKey(remotePubKey)
return conn, nil
}
@ -555,3 +543,59 @@ func detachHandshakeDataChannel(ctx context.Context, dc *webrtc.DataChannel) (da
return nil, ctx.Err()
}
}
// webRTCConnection holds the webrtc.PeerConnection with the handshake channel and the queue for
// incoming data channels created by the peer.
//
// When creating a webrtc.PeerConnection, It is important to set the OnDataChannel handler upfront
// before connecting with the peer. If the handler's set up after connecting with the peer, there's
// a small window of time where datachannels created by the peer may not surface to us and cause a
// memory leak.
type webRTCConnection struct {
PeerConnection *webrtc.PeerConnection
HandshakeDataChannel *webrtc.DataChannel
IncomingDataChannels chan dataChannel
}
func newWebRTCConnection(settings webrtc.SettingEngine, config webrtc.Configuration) (webRTCConnection, error) {
api := webrtc.NewAPI(webrtc.WithSettingEngine(settings))
pc, err := api.NewPeerConnection(config)
if err != nil {
return webRTCConnection{}, fmt.Errorf("failed to create peer connection: %w", err)
}
negotiated, id := handshakeChannelNegotiated, handshakeChannelID
handshakeDataChannel, err := pc.CreateDataChannel("", &webrtc.DataChannelInit{
Negotiated: &negotiated,
ID: &id,
})
if err != nil {
pc.Close()
return webRTCConnection{}, fmt.Errorf("failed to create handshake channel: %w", err)
}
incomingDataChannels := make(chan dataChannel, maxAcceptQueueLen)
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
dc.OnOpen(func() {
rwc, err := dc.Detach()
if err != nil {
log.Warnf("could not detach datachannel: id: %d", *dc.ID())
return
}
select {
case incomingDataChannels <- dataChannel{rwc, dc}:
default:
log.Warnf("connection busy, rejecting stream")
b, _ := proto.Marshal(&pb.Message{Flag: pb.Message_RESET.Enum()})
w := msgio.NewWriter(rwc)
w.WriteMsg(b)
rwc.Close()
}
})
})
return webRTCConnection{
PeerConnection: pc,
HandshakeDataChannel: handshakeDataChannel,
IncomingDataChannels: incomingDataChannels,
}, nil
}

Loading…
Cancel
Save