From 773d8f010885ea6a875de1228cc8e69c2d0a6849 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 24 Oct 2020 16:50:12 +0700 Subject: [PATCH] pass a conn that can be type asserted to a net.UDPConn to quic-go --- p2p/transport/quic/listener.go | 4 +++- p2p/transport/quic/listener_test.go | 30 +++++++++++++++++++++++ p2p/transport/quic/reuse.go | 6 ++--- p2p/transport/quic/transport.go | 4 +++- p2p/transport/quic/transport_test.go | 36 +++++++++++++++++++++++++++- 5 files changed, 74 insertions(+), 6 deletions(-) diff --git a/p2p/transport/quic/listener.go b/p2p/transport/quic/listener.go index a44e4077c..70fe471c4 100644 --- a/p2p/transport/quic/listener.go +++ b/p2p/transport/quic/listener.go @@ -16,6 +16,8 @@ import ( ma "github.com/multiformats/go-multiaddr" ) +var quicListen = quic.Listen // so we can mock it in tests + // A listener listens for QUIC connections. type listener struct { quicListener quic.Listener @@ -38,7 +40,7 @@ func newListener(rconn *reuseConn, t *transport, localPeer peer.ID, key ic.PrivK conf, _ := identity.ConfigForAny() return conf, nil } - ln, err := quic.Listen(rconn, &tlsConf, t.serverConfig) + ln, err := quicListen(rconn, &tlsConf, t.serverConfig) if err != nil { return nil, err } diff --git a/p2p/transport/quic/listener_test.go b/p2p/transport/quic/listener_test.go index 3388f1ac3..cb5836ec3 100644 --- a/p2p/transport/quic/listener_test.go +++ b/p2p/transport/quic/listener_test.go @@ -3,18 +3,29 @@ package libp2pquic import ( "crypto/rand" "crypto/rsa" + "crypto/tls" "crypto/x509" + "errors" "fmt" "net" + "syscall" ic "github.com/libp2p/go-libp2p-core/crypto" tpt "github.com/libp2p/go-libp2p-core/transport" + quic "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) +// interface containing some methods defined on the net.UDPConn, but not the net.PacketConn +type udpConn interface { + ReadFromUDP(b []byte) (int, *net.UDPAddr, error) + SetReadBuffer(bytes int) error + SyscallConn() (syscall.RawConn, error) +} + var _ = Describe("Listener", func() { var t tpt.Transport @@ -27,6 +38,25 @@ var _ = Describe("Listener", func() { Expect(err).ToNot(HaveOccurred()) }) + It("uses a conn that can interface assert to a UDPConn for listening", func() { + origQuicListen := quicListen + defer func() { quicListen = origQuicListen }() + + var conn net.PacketConn + quicListen = func(c net.PacketConn, _ *tls.Config, _ *quic.Config) (quic.Listener, error) { + conn = c + return nil, errors.New("listen error") + } + localAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") + Expect(err).ToNot(HaveOccurred()) + _, err = t.Listen(localAddr) + Expect(err).To(MatchError("listen error")) + Expect(conn).ToNot(BeNil()) + defer conn.Close() + _, ok := conn.(udpConn) + Expect(ok).To(BeTrue()) + }) + Context("listening on the right address", func() { It("returns the address it is listening on", func() { localAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") diff --git a/p2p/transport/quic/reuse.go b/p2p/transport/quic/reuse.go index e9582cd95..3aa694029 100644 --- a/p2p/transport/quic/reuse.go +++ b/p2p/transport/quic/reuse.go @@ -17,15 +17,15 @@ var ( ) type reuseConn struct { - net.PacketConn + *net.UDPConn mutex sync.Mutex refCount int unusedSince time.Time } -func newReuseConn(conn net.PacketConn, gater connmgr.ConnectionGater) *reuseConn { - return &reuseConn{PacketConn: conn} +func newReuseConn(conn *net.UDPConn, gater connmgr.ConnectionGater) *reuseConn { + return &reuseConn{UDPConn: conn} } func (c *reuseConn) IncreaseCount() { diff --git a/p2p/transport/quic/transport.go b/p2p/transport/quic/transport.go index c0d6055af..b28b8b0aa 100644 --- a/p2p/transport/quic/transport.go +++ b/p2p/transport/quic/transport.go @@ -27,6 +27,8 @@ import ( var log = logging.Logger("quic-transport") +var quicDialContext = quic.DialContext // so we can mock it in tests + var quicConfig = &quic.Config{ MaxIncomingStreams: 1000, MaxIncomingUniStreams: -1, // disable unidirectional streams @@ -157,7 +159,7 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp if err != nil { return nil, err } - sess, err := quic.DialContext(ctx, pconn, addr, host, tlsConf, t.clientConfig) + sess, err := quicDialContext(ctx, pconn, addr, host, tlsConf, t.clientConfig) if err != nil { pconn.DecreaseCount() return nil, err diff --git a/p2p/transport/quic/transport_test.go b/p2p/transport/quic/transport_test.go index 111b702c8..226ec2f9c 100644 --- a/p2p/transport/quic/transport_test.go +++ b/p2p/transport/quic/transport_test.go @@ -1,7 +1,17 @@ package libp2pquic import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "errors" + "net" + + ic "github.com/libp2p/go-libp2p-core/crypto" tpt "github.com/libp2p/go-libp2p-core/transport" + quic "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" . "github.com/onsi/ginkgo" @@ -12,7 +22,12 @@ var _ = Describe("Transport", func() { var t tpt.Transport BeforeEach(func() { - t = &transport{} + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) + Expect(err).ToNot(HaveOccurred()) + t, err = NewTransport(key, nil, nil) + Expect(err).ToNot(HaveOccurred()) }) It("says if it can dial an address", func() { @@ -35,4 +50,23 @@ var _ = Describe("Transport", func() { Expect(protocols).To(HaveLen(1)) Expect(protocols[0]).To(Equal(ma.P_QUIC)) }) + + It("uses a conn that can interface assert to a UDPConn for dialing", func() { + origQuicDialContext := quicDialContext + defer func() { quicDialContext = origQuicDialContext }() + + var conn net.PacketConn + quicDialContext = func(_ context.Context, c net.PacketConn, _ net.Addr, _ string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + conn = c + return nil, errors.New("listen error") + } + remoteAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") + Expect(err).ToNot(HaveOccurred()) + _, err = t.Dial(context.Background(), remoteAddr, "remote peer id") + Expect(err).To(MatchError("listen error")) + Expect(conn).ToNot(BeNil()) + defer conn.Close() + _, ok := conn.(udpConn) + Expect(ok).To(BeTrue()) + }) })