Browse Source

implement a Transport.Close that waits for the reuse's GC to finish

pull/1424/head
Marten Seemann 3 years ago
parent
commit
2cef5be872
  1. 20
      p2p/transport/quic/conn_test.go
  2. 13
      p2p/transport/quic/libp2pquic_suite_test.go
  3. 7
      p2p/transport/quic/listener_test.go
  4. 72
      p2p/transport/quic/reuse.go
  5. 14
      p2p/transport/quic/reuse_test.go
  6. 11
      p2p/transport/quic/transport.go
  7. 5
      p2p/transport/quic/transport_test.go

20
p2p/transport/quic/conn_test.go

@ -5,6 +5,7 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
mrand "math/rand" mrand "math/rand"
"net" "net"
@ -70,11 +71,13 @@ var _ = Describe("Connection", func() {
It("handshakes on IPv4", func() { It("handshakes on IPv4", func() {
serverTransport, err := NewTransport(serverKey, nil, nil) serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close() defer ln.Close()
clientTransport, err := NewTransport(clientKey, nil, nil) clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer conn.Close() defer conn.Close()
@ -94,11 +97,13 @@ var _ = Describe("Connection", func() {
It("handshakes on IPv6", func() { It("handshakes on IPv6", func() {
serverTransport, err := NewTransport(serverKey, nil, nil) serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip6/::1/udp/0/quic") ln := runServer(serverTransport, "/ip6/::1/udp/0/quic")
defer ln.Close() defer ln.Close()
clientTransport, err := NewTransport(clientKey, nil, nil) clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer conn.Close() defer conn.Close()
@ -118,11 +123,13 @@ var _ = Describe("Connection", func() {
It("opens and accepts streams", func() { It("opens and accepts streams", func() {
serverTransport, err := NewTransport(serverKey, nil, nil) serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close() defer ln.Close()
clientTransport, err := NewTransport(clientKey, nil, nil) clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer conn.Close() defer conn.Close()
@ -147,6 +154,7 @@ var _ = Describe("Connection", func() {
serverTransport, err := NewTransport(serverKey, nil, nil) serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
clientTransport, err := NewTransport(clientKey, nil, nil) clientTransport, err := NewTransport(clientKey, nil, nil)
@ -154,6 +162,7 @@ var _ = Describe("Connection", func() {
// dial, but expect the wrong peer ID // dial, but expect the wrong peer ID
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID) _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
defer clientTransport.(io.Closer).Close()
Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR")) Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR"))
done := make(chan struct{}) done := make(chan struct{})
@ -172,6 +181,7 @@ var _ = Describe("Connection", func() {
cg.EXPECT().InterceptAccept(gomock.Any()) cg.EXPECT().InterceptAccept(gomock.Any())
serverTransport, err := NewTransport(serverKey, nil, cg) serverTransport, err := NewTransport(serverKey, nil, cg)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close() defer ln.Close()
@ -185,6 +195,7 @@ var _ = Describe("Connection", func() {
clientTransport, err := NewTransport(clientKey, nil, nil) clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
// make sure that connection attempts fails // make sure that connection attempts fails
conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -205,6 +216,7 @@ var _ = Describe("Connection", func() {
It("gates secured connections", func() { It("gates secured connections", func() {
serverTransport, err := NewTransport(serverKey, nil, nil) serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln.Close() defer ln.Close()
@ -213,6 +225,7 @@ var _ = Describe("Connection", func() {
clientTransport, err := NewTransport(clientKey, nil, cg) clientTransport, err := NewTransport(clientKey, nil, cg)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
// make sure that connection attempts fails // make sure that connection attempts fails
_, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID)
@ -232,10 +245,12 @@ var _ = Describe("Connection", func() {
serverTransport, err := NewTransport(serverKey, nil, nil) serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln1 := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") ln1 := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
defer ln1.Close() defer ln1.Close()
serverTransport2, err := NewTransport(serverKey2, nil, nil) serverTransport2, err := NewTransport(serverKey2, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer serverTransport2.(io.Closer).Close()
ln2 := runServer(serverTransport2, "/ip4/127.0.0.1/udp/0/quic") ln2 := runServer(serverTransport2, "/ip4/127.0.0.1/udp/0/quic")
defer ln2.Close() defer ln2.Close()
@ -262,6 +277,7 @@ var _ = Describe("Connection", func() {
clientTransport, err := NewTransport(clientKey, nil, nil) clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID) c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer c1.Close() defer c1.Close()
@ -291,6 +307,7 @@ var _ = Describe("Connection", func() {
It("sends stateless resets", func() { It("sends stateless resets", func() {
serverTransport, err := NewTransport(serverKey, nil, nil) serverTransport, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer serverTransport.(io.Closer).Close()
ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic")
var drop uint32 var drop uint32
@ -307,6 +324,7 @@ var _ = Describe("Connection", func() {
// establish a connection // establish a connection
clientTransport, err := NewTransport(clientKey, nil, nil) clientTransport, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer clientTransport.(io.Closer).Close()
proxyAddr, err := toQuicMultiaddr(proxy.LocalAddr()) proxyAddr, err := toQuicMultiaddr(proxy.LocalAddr())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
conn, err := clientTransport.Dial(context.Background(), proxyAddr, serverID) conn, err := clientTransport.Dial(context.Background(), proxyAddr, serverID)
@ -349,6 +367,7 @@ var _ = Describe("Connection", func() {
It("hole punches", func() { It("hole punches", func() {
t1, err := NewTransport(serverKey, nil, nil) t1, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer t1.(io.Closer).Close()
laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
ln1, err := t1.Listen(laddr) ln1, err := t1.Listen(laddr)
@ -364,6 +383,7 @@ var _ = Describe("Connection", func() {
t2, err := NewTransport(clientKey, nil, nil) t2, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
defer t2.(io.Closer).Close()
ln2, err := t2.Listen(laddr) ln2, err := t2.Listen(laddr)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
done2 := make(chan struct{}) done2 := make(chan struct{})

13
p2p/transport/quic/libp2pquic_suite_test.go

@ -1,14 +1,11 @@
package libp2pquic package libp2pquic
import ( import (
"bytes"
mrand "math/rand" mrand "math/rand"
"runtime/pprof"
"strings"
"testing" "testing"
"time" "time"
gomock "github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@ -31,16 +28,9 @@ var (
mockCtrl *gomock.Controller mockCtrl *gomock.Controller
) )
func isGarbageCollectorRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuse).runGarbageCollector")
}
var _ = BeforeEach(func() { var _ = BeforeEach(func() {
mockCtrl = gomock.NewController(GinkgoT()) mockCtrl = gomock.NewController(GinkgoT())
Expect(isGarbageCollectorRunning()).To(BeFalse())
garbageCollectIntervalOrig = garbageCollectInterval garbageCollectIntervalOrig = garbageCollectInterval
maxUnusedDurationOrig = maxUnusedDuration maxUnusedDurationOrig = maxUnusedDuration
garbageCollectInterval = 50 * time.Millisecond garbageCollectInterval = 50 * time.Millisecond
@ -52,7 +42,6 @@ var _ = BeforeEach(func() {
var _ = AfterEach(func() { var _ = AfterEach(func() {
mockCtrl.Finish() mockCtrl.Finish()
Eventually(isGarbageCollectorRunning).Should(BeFalse())
garbageCollectInterval = garbageCollectIntervalOrig garbageCollectInterval = garbageCollectIntervalOrig
maxUnusedDuration = maxUnusedDurationOrig maxUnusedDuration = maxUnusedDurationOrig
quicConfig = origQuicConfig quicConfig = origQuicConfig

7
p2p/transport/quic/listener_test.go

@ -7,12 +7,13 @@ import (
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"syscall" "syscall"
ic "github.com/libp2p/go-libp2p-core/crypto" ic "github.com/libp2p/go-libp2p-core/crypto"
tpt "github.com/libp2p/go-libp2p-core/transport" tpt "github.com/libp2p/go-libp2p-core/transport"
quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
@ -38,6 +39,10 @@ var _ = Describe("Listener", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
AfterEach(func() {
Expect(t.(io.Closer).Close()).To(Succeed())
})
It("uses a conn that can interface assert to a UDPConn for listening", func() { It("uses a conn that can interface assert to a UDPConn for listening", func() {
origQuicListen := quicListen origQuicListen := quicListen
defer func() { quicListen = origQuicListen }() defer func() { quicListen = origQuicListen }()

72
p2p/transport/quic/reuse.go

@ -53,6 +53,9 @@ type reuse struct {
garbageCollectorRunning bool garbageCollectorRunning bool
closeChan chan struct{}
garbageCollectorStopChan chan struct{}
unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn unicast map[string] /* IP.String() */ map[int] /* port */ *reuseConn
// global contains connections that are listening on 0.0.0.0 / :: // global contains connections that are listening on 0.0.0.0 / ::
global map[int]*reuseConn global map[int]*reuseConn
@ -60,45 +63,52 @@ type reuse struct {
func newReuse() *reuse { func newReuse() *reuse {
return &reuse{ return &reuse{
unicast: make(map[string]map[int]*reuseConn), unicast: make(map[string]map[int]*reuseConn),
global: make(map[int]*reuseConn), global: make(map[int]*reuseConn),
closeChan: make(chan struct{}),
} }
} }
func (r *reuse) runGarbageCollector() { func (r *reuse) runGarbageCollector() {
defer close(r.garbageCollectorStopChan)
ticker := time.NewTicker(garbageCollectInterval) ticker := time.NewTicker(garbageCollectInterval)
defer ticker.Stop() defer ticker.Stop()
for now := range ticker.C { for {
var shouldExit bool select {
r.mutex.Lock() case <-r.closeChan:
for key, conn := range r.global { return
if conn.ShouldGarbageCollect(now) { case now := <-ticker.C:
conn.Close() var shouldExit bool
delete(r.global, key) r.mutex.Lock()
} for key, conn := range r.global {
}
for ukey, conns := range r.unicast {
for key, conn := range conns {
if conn.ShouldGarbageCollect(now) { if conn.ShouldGarbageCollect(now) {
conn.Close() conn.Close()
delete(conns, key) delete(r.global, key)
} }
} }
if len(conns) == 0 { for ukey, conns := range r.unicast {
delete(r.unicast, ukey) for key, conn := range conns {
if conn.ShouldGarbageCollect(now) {
conn.Close()
delete(conns, key)
}
}
if len(conns) == 0 {
delete(r.unicast, ukey)
}
} }
}
// stop the garbage collector if we're not tracking any connections // stop the garbage collector if we're not tracking any connections
if len(r.global) == 0 && len(r.unicast) == 0 { if len(r.global) == 0 && len(r.unicast) == 0 {
r.garbageCollectorRunning = false r.garbageCollectorRunning = false
shouldExit = true shouldExit = true
} }
r.mutex.Unlock() r.mutex.Unlock()
if shouldExit { if shouldExit {
return return
}
} }
} }
} }
@ -107,6 +117,7 @@ func (r *reuse) runGarbageCollector() {
func (r *reuse) maybeStartGarbageCollector() { func (r *reuse) maybeStartGarbageCollector() {
if !r.garbageCollectorRunning { if !r.garbageCollectorRunning {
r.garbageCollectorRunning = true r.garbageCollectorRunning = true
r.garbageCollectorStopChan = make(chan struct{})
go r.runGarbageCollector() go r.runGarbageCollector()
} }
} }
@ -199,3 +210,14 @@ func (r *reuse) Listen(network string, laddr *net.UDPAddr) (*reuseConn, error) {
r.unicast[localAddr.IP.String()][localAddr.Port] = rconn r.unicast[localAddr.IP.String()][localAddr.Port] = rconn
return rconn, err return rconn, err
} }
func (r *reuse) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
close(r.closeChan)
if r.garbageCollectorRunning {
<-r.garbageCollectorStopChan
r.garbageCollectorRunning = false
}
return nil
}

14
p2p/transport/quic/reuse_test.go

@ -1,7 +1,10 @@
package libp2pquic package libp2pquic
import ( import (
"bytes"
"net" "net"
"runtime/pprof"
"strings"
"time" "time"
"github.com/libp2p/go-netroute" "github.com/libp2p/go-netroute"
@ -30,7 +33,6 @@ func closeAllConns(reuse *reuse) {
} }
} }
reuse.mutex.Unlock() reuse.mutex.Unlock()
Eventually(isGarbageCollectorRunning).Should(BeFalse())
} }
func OnPlatformsWithRoutingTablesIt(description string, f interface{}) { func OnPlatformsWithRoutingTablesIt(description string, f interface{}) {
@ -48,6 +50,16 @@ var _ = Describe("Reuse", func() {
reuse = newReuse() reuse = newReuse()
}) })
AfterEach(func() {
Expect(reuse.Close()).To(Succeed())
})
isGarbageCollectorRunning := func() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuse).runGarbageCollector")
}
Context("creating and reusing connections", func() { Context("creating and reusing connections", func() {
AfterEach(func() { closeAllConns(reuse) }) AfterEach(func() { closeAllConns(reuse) })

11
p2p/transport/quic/transport.go

@ -93,6 +93,13 @@ func (c *connManager) Dial(network string, raddr *net.UDPAddr) (*reuseConn, erro
return reuse.Dial(network, raddr) return reuse.Dial(network, raddr)
} }
func (c *connManager) Close() error {
if err := c.reuseUDP6.Close(); err != nil {
return err
}
return c.reuseUDP4.Close()
}
// The Transport implements the tpt.Transport interface for QUIC connections. // The Transport implements the tpt.Transport interface for QUIC connections.
type transport struct { type transport struct {
privKey ic.PrivKey privKey ic.PrivKey
@ -346,3 +353,7 @@ func (t *transport) Protocols() []int {
func (t *transport) String() string { func (t *transport) String() string {
return "QUIC" return "QUIC"
} }
func (t *transport) Close() error {
return t.connManager.Close()
}

5
p2p/transport/quic/transport_test.go

@ -7,6 +7,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors" "errors"
"io"
"net" "net"
ic "github.com/libp2p/go-libp2p-core/crypto" ic "github.com/libp2p/go-libp2p-core/crypto"
@ -30,6 +31,10 @@ var _ = Describe("Transport", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
}) })
AfterEach(func() {
Expect(t.(io.Closer).Close()).To(Succeed())
})
It("says if it can dial an address", func() { It("says if it can dial an address", func() {
invalidAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/1234") invalidAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/1234")
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

Loading…
Cancel
Save