diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index 0e02d7a65..af2991ed8 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -57,6 +57,7 @@ type listener struct { // used to control the lifecycle of the listener ctx context.Context cancel context.CancelFunc + wg sync.WaitGroup } var _ tpt.Listener = &listener{} @@ -91,30 +92,27 @@ func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.Pack } l.ctx, l.cancel = context.WithCancel(context.Background()) - mux := udpmux.NewUDPMux(socket) - l.mux = mux - mux.Start() + l.mux = udpmux.NewUDPMux(socket) + l.mux.Start() - go l.listen() + l.wg.Add(1) + go func() { + defer l.wg.Done() + l.listen() + }() return l, err } func (l *listener) listen() { - // Accepting a connection requires instantiating a peerconnection - // and a noise connection which is expensive. We therefore limit - // the number of in-flight connection requests. A connection - // is considered to be in flight from the instant it is handled - // until it is dequeued by a call to Accept, or errors out in some - // way. - inFlightQueueCh := make(chan struct{}, l.transport.maxInFlightConnections) - for i := uint32(0); i < l.transport.maxInFlightConnections; i++ { - inFlightQueueCh <- struct{}{} - } - + // Accepting a connection requires instantiating a peerconnection and a noise connection + // which is expensive. We therefore limit the number of in-flight connection requests. A + // connection is considered to be in flight from the instant it is handled until it is + // dequeued by a call to Accept, or errors out in some way. + inFlightSemaphore := make(chan struct{}, l.transport.maxInFlightConnections) for { select { - case <-inFlightQueueCh: + case inFlightSemaphore <- struct{}{}: case <-l.ctx.Done(): return } @@ -128,7 +126,7 @@ func (l *listener) listen() { } go func() { - defer func() { inFlightQueueCh <- struct{}{} }() // free this spot once again + defer func() { <-inFlightSemaphore }() ctx, cancel := context.WithTimeout(l.ctx, candidateSetupTimeout) defer cancel() @@ -145,7 +143,7 @@ func (l *listener) listen() { log.Warn("could not push connection: ctx done") conn.Close() case l.acceptQueue <- conn: - // acceptQueue is an unbuffered channel, so this block until the connection is accepted. + // acceptQueue is an unbuffered channel, so this blocks until the connection is accepted. } }() } @@ -307,7 +305,18 @@ func (l *listener) Close() error { select { case <-l.ctx.Done(): default: - l.cancel() + } + l.cancel() + l.mux.Close() + l.wg.Wait() +loop: + for { + select { + case conn := <-l.acceptQueue: + conn.Close() + default: + break loop + } } return nil } diff --git a/p2p/transport/webrtc/stream_test.go b/p2p/transport/webrtc/stream_test.go index 8f1ec165c..52b464c0e 100644 --- a/p2p/transport/webrtc/stream_test.go +++ b/p2p/transport/webrtc/stream_test.go @@ -366,7 +366,7 @@ func TestStreamCloseAfterFINACK(t *testing.T) { select { case <-done: t.Fatalf("Close should not have completed without processing FIN_ACK") - case <-time.After(2 * time.Second): + case <-time.After(200 * time.Millisecond): } b := make([]byte, 1) diff --git a/p2p/transport/webrtc/transport_test.go b/p2p/transport/webrtc/transport_test.go index 3bfcaaf5a..e5ef3ca1d 100644 --- a/p2p/transport/webrtc/transport_test.go +++ b/p2p/transport/webrtc/transport_test.go @@ -106,6 +106,7 @@ func TestTransportWebRTC_CanListenSingle(t *testing.T) { listener, err := tr.Listen(listenMultiaddr) require.NoError(t, err) + defer listener.Close() done := make(chan struct{}) go func() { @@ -146,6 +147,7 @@ func TestTransportWebRTC_CanListenMultiple(t *testing.T) { listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") listener, err := tr.Listen(listenMultiaddr) require.NoError(t, err) + defer listener.Close() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -156,6 +158,7 @@ func TestTransportWebRTC_CanListenMultiple(t *testing.T) { conn, err := listener.Accept() assert.NoError(t, err) assert.NotNil(t, conn) + defer conn.Close() } wg.Wait() cancel() @@ -172,6 +175,7 @@ func TestTransportWebRTC_CanListenMultiple(t *testing.T) { default: assert.NoError(t, err) assert.NotNil(t, conn) + t.Cleanup(func() { conn.Close() }) } }() } @@ -188,21 +192,29 @@ func TestTransportWebRTC_CanCreateSuccessiveConnections(t *testing.T) { listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") listener, err := tr.Listen(listenMultiaddr) require.NoError(t, err) + defer listener.Close() + count := 2 + var wg sync.WaitGroup + wg.Add(count) go func() { for i := 0; i < count; i++ { ctr, _ := getTransport(t) conn, err := ctr.Dial(context.Background(), listener.Multiaddr(), listeningPeer) require.NoError(t, err) require.Equal(t, conn.RemotePeer(), listeningPeer) + t.Cleanup(func() { conn.Close() }) + wg.Done() } }() for i := 0; i < count; i++ { - _, err := listener.Accept() + conn, err := listener.Accept() require.NoError(t, err) + defer conn.Close() } + wg.Wait() } func TestTransportWebRTC_ListenerCanCreateStreams(t *testing.T) { @@ -211,12 +223,15 @@ func TestTransportWebRTC_ListenerCanCreateStreams(t *testing.T) { listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") listener, err := tr.Listen(listenMultiaddr) require.NoError(t, err) + defer listener.Close() streamChan := make(chan network.MuxedStream) go func() { conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer) require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) t.Logf("connection opened by dialer") + stream, err := conn.AcceptStream() require.NoError(t, err) t.Logf("dialer accepted stream") @@ -225,8 +240,9 @@ func TestTransportWebRTC_ListenerCanCreateStreams(t *testing.T) { conn, err := listener.Accept() require.NoError(t, err) - t.Logf("listener accepted connection") + defer conn.Close() require.Equal(t, connectingPeer, conn.RemotePeer()) + t.Logf("listener accepted connection") stream, err := conn.OpenStream(context.Background()) require.NoError(t, err) @@ -252,6 +268,7 @@ func TestTransportWebRTC_DialerCanCreateStreams(t *testing.T) { listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") listener, err := tr.Listen(listenMultiaddr) require.NoError(t, err) + defer listener.Close() tr1, connectingPeer := getTransport(t) done := make(chan struct{}) @@ -260,6 +277,7 @@ func TestTransportWebRTC_DialerCanCreateStreams(t *testing.T) { lconn, err := listener.Accept() require.NoError(t, err) require.Equal(t, connectingPeer, lconn.RemotePeer()) + defer lconn.Close() stream, err := lconn.AcceptStream() require.NoError(t, err) @@ -268,19 +286,22 @@ func TestTransportWebRTC_DialerCanCreateStreams(t *testing.T) { require.NoError(t, err) require.Equal(t, "test", string(buf[:n])) - done <- struct{}{} + close(done) }() go func() { conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer) require.NoError(t, err) + defer conn.Close() t.Logf("dialer opened connection") stream, err := conn.OpenStream(context.Background()) require.NoError(t, err) t.Logf("dialer opened stream") _, err = stream.Write([]byte("test")) require.NoError(t, err) + <-done }() + select { case <-done: case <-time.After(10 * time.Second): @@ -293,6 +314,7 @@ func TestTransportWebRTC_DialerCanCreateStreamsMultiple(t *testing.T) { listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") listener, err := tr.Listen(listenMultiaddr) require.NoError(t, err) + defer listener.Close() tr1, connectingPeer := getTransport(t) readerDone := make(chan struct{}) @@ -308,6 +330,7 @@ func TestTransportWebRTC_DialerCanCreateStreamsMultiple(t *testing.T) { lconn, err := listener.Accept() require.NoError(t, err) require.Equal(t, connectingPeer, lconn.RemotePeer()) + defer lconn.Close() var wg sync.WaitGroup var doneStreams atomic.Int32 for i := 0; i < numListeners; i++ { @@ -334,6 +357,7 @@ func TestTransportWebRTC_DialerCanCreateStreamsMultiple(t *testing.T) { conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer) require.NoError(t, err) + defer conn.Close() var writerWG sync.WaitGroup var cnt atomic.Int32 @@ -381,12 +405,14 @@ func TestTransportWebRTC_Deadline(t *testing.T) { listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") listener, err := tr.Listen(listenMultiaddr) require.NoError(t, err) + defer listener.Close() tr1, connectingPeer := getTransport(t) t.Run("SetReadDeadline", func(t *testing.T) { go func() { lconn, err := listener.Accept() require.NoError(t, err) + t.Cleanup(func() { lconn.Close() }) require.Equal(t, connectingPeer, lconn.RemotePeer()) _, err = lconn.AcceptStream() require.NoError(t, err) @@ -394,6 +420,7 @@ func TestTransportWebRTC_Deadline(t *testing.T) { conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer) require.NoError(t, err) + defer conn.Close() stream, err := conn.OpenStream(context.Background()) require.NoError(t, err) @@ -412,6 +439,7 @@ func TestTransportWebRTC_Deadline(t *testing.T) { go func() { lconn, err := listener.Accept() require.NoError(t, err) + t.Cleanup(func() { lconn.Close() }) require.Equal(t, connectingPeer, lconn.RemotePeer()) _, err = lconn.AcceptStream() require.NoError(t, err) @@ -419,6 +447,7 @@ func TestTransportWebRTC_Deadline(t *testing.T) { conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer) require.NoError(t, err) + defer conn.Close() stream, err := conn.OpenStream(context.Background()) require.NoError(t, err) @@ -439,21 +468,30 @@ func TestTransportWebRTC_StreamWriteBufferContention(t *testing.T) { listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") listener, err := tr.Listen(listenMultiaddr) require.NoError(t, err) + defer listener.Close() tr1, connectingPeer := getTransport(t) - for i := 0; i < 2; i++ { - go func() { - lconn, err := listener.Accept() - require.NoError(t, err) - require.Equal(t, connectingPeer, lconn.RemotePeer()) - _, err = lconn.AcceptStream() - require.NoError(t, err) - }() - } + var wg sync.WaitGroup + wg.Add(3) + go func() { + defer wg.Done() + lconn, err := listener.Accept() + require.NoError(t, err) + t.Cleanup(func() { lconn.Close() }) + require.Equal(t, connectingPeer, lconn.RemotePeer()) + for i := 0; i < 2; i++ { + go func() { + defer wg.Done() + _, err := lconn.AcceptStream() + require.NoError(t, err) + }() + } + }() conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer) require.NoError(t, err) + defer conn.Close() errC := make(chan error) // writers @@ -471,6 +509,7 @@ func TestTransportWebRTC_StreamWriteBufferContention(t *testing.T) { require.ErrorIs(t, <-errC, os.ErrDeadlineExceeded) require.ErrorIs(t, <-errC, os.ErrDeadlineExceeded) + wg.Wait() } func TestTransportWebRTC_RemoteReadsAfterClose(t *testing.T) { @@ -478,6 +517,7 @@ func TestTransportWebRTC_RemoteReadsAfterClose(t *testing.T) { listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") listener, err := tr.Listen(listenMultiaddr) require.NoError(t, err) + defer listener.Close() tr1, _ := getTransport(t) @@ -488,6 +528,8 @@ func TestTransportWebRTC_RemoteReadsAfterClose(t *testing.T) { done <- err return } + t.Cleanup(func() { lconn.Close() }) + stream, err := lconn.AcceptStream() if err != nil { done <- err @@ -508,6 +550,7 @@ func TestTransportWebRTC_RemoteReadsAfterClose(t *testing.T) { conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer) require.NoError(t, err) + defer conn.Close() // create a stream stream, err := conn.OpenStream(context.Background()) @@ -527,6 +570,7 @@ func TestTransportWebRTC_RemoteReadsAfterClose2(t *testing.T) { listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") listener, err := tr.Listen(listenMultiaddr) require.NoError(t, err) + defer listener.Close() tr1, _ := getTransport(t) @@ -539,6 +583,7 @@ func TestTransportWebRTC_RemoteReadsAfterClose2(t *testing.T) { done <- err return } + defer lconn.Close() stream, err := lconn.AcceptStream() if err != nil { done <- err @@ -558,6 +603,7 @@ func TestTransportWebRTC_RemoteReadsAfterClose2(t *testing.T) { conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer) require.NoError(t, err) + defer conn.Close() // create a stream stream, err := conn.OpenStream(context.Background()) require.NoError(t, err) @@ -575,6 +621,7 @@ func TestTransportWebRTC_Close(t *testing.T) { listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") listener, err := tr.Listen(listenMultiaddr) require.NoError(t, err) + defer listener.Close() tr1, connectingPeer := getTransport(t) @@ -585,6 +632,7 @@ func TestTransportWebRTC_Close(t *testing.T) { defer wg.Done() lconn, err := listener.Accept() require.NoError(t, err) + t.Cleanup(func() { lconn.Close() }) require.Equal(t, connectingPeer, lconn.RemotePeer()) stream, err := lconn.AcceptStream() require.NoError(t, err) @@ -596,6 +644,7 @@ func TestTransportWebRTC_Close(t *testing.T) { conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer) require.NoError(t, err) + defer conn.Close() stream, err := conn.OpenStream(context.Background()) require.NoError(t, err) @@ -662,6 +711,7 @@ func TestConnectionTimeoutOnListener(t *testing.T) { addr = addr.Encapsulate(webrtcComponent) conn, err := tr1.Dial(ctx, addr, listeningPeer) require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) str, err := conn.OpenStream(ctx) require.NoError(t, err) str.Write([]byte("foobar")) @@ -670,6 +720,7 @@ func TestConnectionTimeoutOnListener(t *testing.T) { conn, err := ln.Accept() require.NoError(t, err) require.Equal(t, connectingPeer, conn.RemotePeer()) + defer conn.Close() str, err := conn.AcceptStream() require.NoError(t, err) @@ -729,8 +780,9 @@ func TestMaxInFlightRequests(t *testing.T) { dialer, _ := getTransport(t) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - if _, err := dialer.Dial(ctx, ln.Multiaddr(), listeningPeer); err == nil { + if conn, err := dialer.Dial(ctx, ln.Multiaddr(), listeningPeer); err == nil { success.Add(1) + t.Cleanup(func() { conn.Close() }) } else { t.Log("failed to dial:", err) fails.Add(1) diff --git a/p2p/transport/webrtc/udpmux/mux.go b/p2p/transport/webrtc/udpmux/mux.go index 30cc61eda..98e60e3f1 100644 --- a/p2p/transport/webrtc/udpmux/mux.go +++ b/p2p/transport/webrtc/udpmux/mux.go @@ -255,6 +255,9 @@ func ufragFromSTUNMessage(msg *stun.Message) (string, error) { return string(attr[index+1:]), nil } +// RemoveConnByUfrag removes the connection associated with the ufrag and all the +// addresses associated with that connection. This method is called by pion when +// a peerconnection is closed. func (mux *UDPMux) RemoveConnByUfrag(ufrag string) { if ufrag == "" { return