Browse Source

webrtc: close mux when closing listener (#2717)

There is currently a leak in the webrtc listener. When the listener
is closed the udp mux readloop just keeps running.
pull/2718/head
Sukun 8 months ago
committed by GitHub
parent
commit
6130b9c75c
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 47
      p2p/transport/webrtc/listener.go
  2. 2
      p2p/transport/webrtc/stream_test.go
  3. 78
      p2p/transport/webrtc/transport_test.go
  4. 3
      p2p/transport/webrtc/udpmux/mux.go

47
p2p/transport/webrtc/listener.go

@ -57,6 +57,7 @@ type listener struct {
// used to control the lifecycle of the listener
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
var _ tpt.Listener = &listener{}
@ -91,30 +92,27 @@ func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.Pack
}
l.ctx, l.cancel = context.WithCancel(context.Background())
mux := udpmux.NewUDPMux(socket)
l.mux = mux
mux.Start()
l.mux = udpmux.NewUDPMux(socket)
l.mux.Start()
go l.listen()
l.wg.Add(1)
go func() {
defer l.wg.Done()
l.listen()
}()
return l, err
}
func (l *listener) listen() {
// Accepting a connection requires instantiating a peerconnection
// and a noise connection which is expensive. We therefore limit
// the number of in-flight connection requests. A connection
// is considered to be in flight from the instant it is handled
// until it is dequeued by a call to Accept, or errors out in some
// way.
inFlightQueueCh := make(chan struct{}, l.transport.maxInFlightConnections)
for i := uint32(0); i < l.transport.maxInFlightConnections; i++ {
inFlightQueueCh <- struct{}{}
}
// Accepting a connection requires instantiating a peerconnection and a noise connection
// which is expensive. We therefore limit the number of in-flight connection requests. A
// connection is considered to be in flight from the instant it is handled until it is
// dequeued by a call to Accept, or errors out in some way.
inFlightSemaphore := make(chan struct{}, l.transport.maxInFlightConnections)
for {
select {
case <-inFlightQueueCh:
case inFlightSemaphore <- struct{}{}:
case <-l.ctx.Done():
return
}
@ -128,7 +126,7 @@ func (l *listener) listen() {
}
go func() {
defer func() { inFlightQueueCh <- struct{}{} }() // free this spot once again
defer func() { <-inFlightSemaphore }()
ctx, cancel := context.WithTimeout(l.ctx, candidateSetupTimeout)
defer cancel()
@ -145,7 +143,7 @@ func (l *listener) listen() {
log.Warn("could not push connection: ctx done")
conn.Close()
case l.acceptQueue <- conn:
// acceptQueue is an unbuffered channel, so this block until the connection is accepted.
// acceptQueue is an unbuffered channel, so this blocks until the connection is accepted.
}
}()
}
@ -307,7 +305,18 @@ func (l *listener) Close() error {
select {
case <-l.ctx.Done():
default:
l.cancel()
}
l.cancel()
l.mux.Close()
l.wg.Wait()
loop:
for {
select {
case conn := <-l.acceptQueue:
conn.Close()
default:
break loop
}
}
return nil
}

2
p2p/transport/webrtc/stream_test.go

@ -366,7 +366,7 @@ func TestStreamCloseAfterFINACK(t *testing.T) {
select {
case <-done:
t.Fatalf("Close should not have completed without processing FIN_ACK")
case <-time.After(2 * time.Second):
case <-time.After(200 * time.Millisecond):
}
b := make([]byte, 1)

78
p2p/transport/webrtc/transport_test.go

@ -106,6 +106,7 @@ func TestTransportWebRTC_CanListenSingle(t *testing.T) {
listener, err := tr.Listen(listenMultiaddr)
require.NoError(t, err)
defer listener.Close()
done := make(chan struct{})
go func() {
@ -146,6 +147,7 @@ func TestTransportWebRTC_CanListenMultiple(t *testing.T) {
listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct")
listener, err := tr.Listen(listenMultiaddr)
require.NoError(t, err)
defer listener.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -156,6 +158,7 @@ func TestTransportWebRTC_CanListenMultiple(t *testing.T) {
conn, err := listener.Accept()
assert.NoError(t, err)
assert.NotNil(t, conn)
defer conn.Close()
}
wg.Wait()
cancel()
@ -172,6 +175,7 @@ func TestTransportWebRTC_CanListenMultiple(t *testing.T) {
default:
assert.NoError(t, err)
assert.NotNil(t, conn)
t.Cleanup(func() { conn.Close() })
}
}()
}
@ -188,21 +192,29 @@ func TestTransportWebRTC_CanCreateSuccessiveConnections(t *testing.T) {
listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct")
listener, err := tr.Listen(listenMultiaddr)
require.NoError(t, err)
defer listener.Close()
count := 2
var wg sync.WaitGroup
wg.Add(count)
go func() {
for i := 0; i < count; i++ {
ctr, _ := getTransport(t)
conn, err := ctr.Dial(context.Background(), listener.Multiaddr(), listeningPeer)
require.NoError(t, err)
require.Equal(t, conn.RemotePeer(), listeningPeer)
t.Cleanup(func() { conn.Close() })
wg.Done()
}
}()
for i := 0; i < count; i++ {
_, err := listener.Accept()
conn, err := listener.Accept()
require.NoError(t, err)
defer conn.Close()
}
wg.Wait()
}
func TestTransportWebRTC_ListenerCanCreateStreams(t *testing.T) {
@ -211,12 +223,15 @@ func TestTransportWebRTC_ListenerCanCreateStreams(t *testing.T) {
listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct")
listener, err := tr.Listen(listenMultiaddr)
require.NoError(t, err)
defer listener.Close()
streamChan := make(chan network.MuxedStream)
go func() {
conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer)
require.NoError(t, err)
t.Cleanup(func() { conn.Close() })
t.Logf("connection opened by dialer")
stream, err := conn.AcceptStream()
require.NoError(t, err)
t.Logf("dialer accepted stream")
@ -225,8 +240,9 @@ func TestTransportWebRTC_ListenerCanCreateStreams(t *testing.T) {
conn, err := listener.Accept()
require.NoError(t, err)
t.Logf("listener accepted connection")
defer conn.Close()
require.Equal(t, connectingPeer, conn.RemotePeer())
t.Logf("listener accepted connection")
stream, err := conn.OpenStream(context.Background())
require.NoError(t, err)
@ -252,6 +268,7 @@ func TestTransportWebRTC_DialerCanCreateStreams(t *testing.T) {
listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct")
listener, err := tr.Listen(listenMultiaddr)
require.NoError(t, err)
defer listener.Close()
tr1, connectingPeer := getTransport(t)
done := make(chan struct{})
@ -260,6 +277,7 @@ func TestTransportWebRTC_DialerCanCreateStreams(t *testing.T) {
lconn, err := listener.Accept()
require.NoError(t, err)
require.Equal(t, connectingPeer, lconn.RemotePeer())
defer lconn.Close()
stream, err := lconn.AcceptStream()
require.NoError(t, err)
@ -268,19 +286,22 @@ func TestTransportWebRTC_DialerCanCreateStreams(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "test", string(buf[:n]))
done <- struct{}{}
close(done)
}()
go func() {
conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer)
require.NoError(t, err)
defer conn.Close()
t.Logf("dialer opened connection")
stream, err := conn.OpenStream(context.Background())
require.NoError(t, err)
t.Logf("dialer opened stream")
_, err = stream.Write([]byte("test"))
require.NoError(t, err)
<-done
}()
select {
case <-done:
case <-time.After(10 * time.Second):
@ -293,6 +314,7 @@ func TestTransportWebRTC_DialerCanCreateStreamsMultiple(t *testing.T) {
listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct")
listener, err := tr.Listen(listenMultiaddr)
require.NoError(t, err)
defer listener.Close()
tr1, connectingPeer := getTransport(t)
readerDone := make(chan struct{})
@ -308,6 +330,7 @@ func TestTransportWebRTC_DialerCanCreateStreamsMultiple(t *testing.T) {
lconn, err := listener.Accept()
require.NoError(t, err)
require.Equal(t, connectingPeer, lconn.RemotePeer())
defer lconn.Close()
var wg sync.WaitGroup
var doneStreams atomic.Int32
for i := 0; i < numListeners; i++ {
@ -334,6 +357,7 @@ func TestTransportWebRTC_DialerCanCreateStreamsMultiple(t *testing.T) {
conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer)
require.NoError(t, err)
defer conn.Close()
var writerWG sync.WaitGroup
var cnt atomic.Int32
@ -381,12 +405,14 @@ func TestTransportWebRTC_Deadline(t *testing.T) {
listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct")
listener, err := tr.Listen(listenMultiaddr)
require.NoError(t, err)
defer listener.Close()
tr1, connectingPeer := getTransport(t)
t.Run("SetReadDeadline", func(t *testing.T) {
go func() {
lconn, err := listener.Accept()
require.NoError(t, err)
t.Cleanup(func() { lconn.Close() })
require.Equal(t, connectingPeer, lconn.RemotePeer())
_, err = lconn.AcceptStream()
require.NoError(t, err)
@ -394,6 +420,7 @@ func TestTransportWebRTC_Deadline(t *testing.T) {
conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer)
require.NoError(t, err)
defer conn.Close()
stream, err := conn.OpenStream(context.Background())
require.NoError(t, err)
@ -412,6 +439,7 @@ func TestTransportWebRTC_Deadline(t *testing.T) {
go func() {
lconn, err := listener.Accept()
require.NoError(t, err)
t.Cleanup(func() { lconn.Close() })
require.Equal(t, connectingPeer, lconn.RemotePeer())
_, err = lconn.AcceptStream()
require.NoError(t, err)
@ -419,6 +447,7 @@ func TestTransportWebRTC_Deadline(t *testing.T) {
conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer)
require.NoError(t, err)
defer conn.Close()
stream, err := conn.OpenStream(context.Background())
require.NoError(t, err)
@ -439,21 +468,30 @@ func TestTransportWebRTC_StreamWriteBufferContention(t *testing.T) {
listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct")
listener, err := tr.Listen(listenMultiaddr)
require.NoError(t, err)
defer listener.Close()
tr1, connectingPeer := getTransport(t)
for i := 0; i < 2; i++ {
go func() {
lconn, err := listener.Accept()
require.NoError(t, err)
require.Equal(t, connectingPeer, lconn.RemotePeer())
_, err = lconn.AcceptStream()
require.NoError(t, err)
}()
}
var wg sync.WaitGroup
wg.Add(3)
go func() {
defer wg.Done()
lconn, err := listener.Accept()
require.NoError(t, err)
t.Cleanup(func() { lconn.Close() })
require.Equal(t, connectingPeer, lconn.RemotePeer())
for i := 0; i < 2; i++ {
go func() {
defer wg.Done()
_, err := lconn.AcceptStream()
require.NoError(t, err)
}()
}
}()
conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer)
require.NoError(t, err)
defer conn.Close()
errC := make(chan error)
// writers
@ -471,6 +509,7 @@ func TestTransportWebRTC_StreamWriteBufferContention(t *testing.T) {
require.ErrorIs(t, <-errC, os.ErrDeadlineExceeded)
require.ErrorIs(t, <-errC, os.ErrDeadlineExceeded)
wg.Wait()
}
func TestTransportWebRTC_RemoteReadsAfterClose(t *testing.T) {
@ -478,6 +517,7 @@ func TestTransportWebRTC_RemoteReadsAfterClose(t *testing.T) {
listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct")
listener, err := tr.Listen(listenMultiaddr)
require.NoError(t, err)
defer listener.Close()
tr1, _ := getTransport(t)
@ -488,6 +528,8 @@ func TestTransportWebRTC_RemoteReadsAfterClose(t *testing.T) {
done <- err
return
}
t.Cleanup(func() { lconn.Close() })
stream, err := lconn.AcceptStream()
if err != nil {
done <- err
@ -508,6 +550,7 @@ func TestTransportWebRTC_RemoteReadsAfterClose(t *testing.T) {
conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer)
require.NoError(t, err)
defer conn.Close()
// create a stream
stream, err := conn.OpenStream(context.Background())
@ -527,6 +570,7 @@ func TestTransportWebRTC_RemoteReadsAfterClose2(t *testing.T) {
listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct")
listener, err := tr.Listen(listenMultiaddr)
require.NoError(t, err)
defer listener.Close()
tr1, _ := getTransport(t)
@ -539,6 +583,7 @@ func TestTransportWebRTC_RemoteReadsAfterClose2(t *testing.T) {
done <- err
return
}
defer lconn.Close()
stream, err := lconn.AcceptStream()
if err != nil {
done <- err
@ -558,6 +603,7 @@ func TestTransportWebRTC_RemoteReadsAfterClose2(t *testing.T) {
conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer)
require.NoError(t, err)
defer conn.Close()
// create a stream
stream, err := conn.OpenStream(context.Background())
require.NoError(t, err)
@ -575,6 +621,7 @@ func TestTransportWebRTC_Close(t *testing.T) {
listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct")
listener, err := tr.Listen(listenMultiaddr)
require.NoError(t, err)
defer listener.Close()
tr1, connectingPeer := getTransport(t)
@ -585,6 +632,7 @@ func TestTransportWebRTC_Close(t *testing.T) {
defer wg.Done()
lconn, err := listener.Accept()
require.NoError(t, err)
t.Cleanup(func() { lconn.Close() })
require.Equal(t, connectingPeer, lconn.RemotePeer())
stream, err := lconn.AcceptStream()
require.NoError(t, err)
@ -596,6 +644,7 @@ func TestTransportWebRTC_Close(t *testing.T) {
conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer)
require.NoError(t, err)
defer conn.Close()
stream, err := conn.OpenStream(context.Background())
require.NoError(t, err)
@ -662,6 +711,7 @@ func TestConnectionTimeoutOnListener(t *testing.T) {
addr = addr.Encapsulate(webrtcComponent)
conn, err := tr1.Dial(ctx, addr, listeningPeer)
require.NoError(t, err)
t.Cleanup(func() { conn.Close() })
str, err := conn.OpenStream(ctx)
require.NoError(t, err)
str.Write([]byte("foobar"))
@ -670,6 +720,7 @@ func TestConnectionTimeoutOnListener(t *testing.T) {
conn, err := ln.Accept()
require.NoError(t, err)
require.Equal(t, connectingPeer, conn.RemotePeer())
defer conn.Close()
str, err := conn.AcceptStream()
require.NoError(t, err)
@ -729,8 +780,9 @@ func TestMaxInFlightRequests(t *testing.T) {
dialer, _ := getTransport(t)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if _, err := dialer.Dial(ctx, ln.Multiaddr(), listeningPeer); err == nil {
if conn, err := dialer.Dial(ctx, ln.Multiaddr(), listeningPeer); err == nil {
success.Add(1)
t.Cleanup(func() { conn.Close() })
} else {
t.Log("failed to dial:", err)
fails.Add(1)

3
p2p/transport/webrtc/udpmux/mux.go

@ -255,6 +255,9 @@ func ufragFromSTUNMessage(msg *stun.Message) (string, error) {
return string(attr[index+1:]), nil
}
// RemoveConnByUfrag removes the connection associated with the ufrag and all the
// addresses associated with that connection. This method is called by pion when
// a peerconnection is closed.
func (mux *UDPMux) RemoveConnByUfrag(ufrag string) {
if ufrag == "" {
return

Loading…
Cancel
Save