Browse Source

add hole punching support (#194)

pull/1424/head
vyzo 3 years ago
committed by GitHub
parent
commit
4ff2d36c67
  1. 53
      p2p/transport/quic/conn_test.go
  2. 18
      p2p/transport/quic/listener.go
  3. 110
      p2p/transport/quic/transport.go
  4. 2
      p2p/transport/quic/transport_test.go

53
p2p/transport/quic/conn_test.go

@ -11,14 +11,14 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
gomock "github.com/golang/mock/gomock"
ic "github.com/libp2p/go-libp2p-core/crypto" ic "github.com/libp2p/go-libp2p-core/crypto"
n "github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
tpt "github.com/libp2p/go-libp2p-core/transport" tpt "github.com/libp2p/go-libp2p-core/transport"
quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
"github.com/golang/mock/gomock"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
@ -345,4 +345,53 @@ var _ = Describe("Connection", func() {
Expect(rerr).To(HaveOccurred()) Expect(rerr).To(HaveOccurred())
Expect(rerr.Error()).To(ContainSubstring("received a stateless reset")) Expect(rerr.Error()).To(ContainSubstring("received a stateless reset"))
}) })
It("hole punches", func() {
t1, err := NewTransport(serverKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic")
Expect(err).ToNot(HaveOccurred())
ln1, err := t1.Listen(laddr)
Expect(err).ToNot(HaveOccurred())
done1 := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done1)
if _, err := ln1.Accept(); err == nil {
Fail("didn't expect to accept any connections")
}
}()
t2, err := NewTransport(clientKey, nil, nil)
Expect(err).ToNot(HaveOccurred())
ln2, err := t2.Listen(laddr)
Expect(err).ToNot(HaveOccurred())
done2 := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done2)
if _, err := ln2.Accept(); err == nil {
Fail("didn't expect to accept any connections")
}
}()
connChan := make(chan tpt.CapableConn)
go func() {
defer GinkgoRecover()
conn, err := t2.Dial(n.WithSimultaneousConnect(context.Background(), ""), ln1.Multiaddr(), serverID)
Expect(err).ToNot(HaveOccurred())
connChan <- conn
}()
conn1, err := t1.Dial(n.WithSimultaneousConnect(context.Background(), ""), ln2.Multiaddr(), clientID)
Expect(err).ToNot(HaveOccurred())
defer conn1.Close()
Expect(conn1.RemotePeer()).To(Equal(clientID))
var conn2 tpt.CapableConn
Eventually(connChan).Should(Receive(&conn2))
defer conn2.Close()
Expect(conn2.RemotePeer()).To(Equal(serverID))
ln1.Close()
ln2.Close()
Eventually(done1).Should(BeClosed())
Eventually(done2).Should(BeClosed())
})
}) })

18
p2p/transport/quic/listener.go

@ -12,7 +12,7 @@ import (
p2ptls "github.com/libp2p/go-libp2p-tls" p2ptls "github.com/libp2p/go-libp2p-tls"
quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
) )
@ -74,6 +74,21 @@ func (l *listener) Accept() (tpt.CapableConn, error) {
sess.CloseWithError(errorCodeConnectionGating, "connection gated") sess.CloseWithError(errorCodeConnectionGating, "connection gated")
continue continue
} }
// return through active hole punching if any
key := holePunchKey{addr: sess.RemoteAddr().String(), peer: conn.remotePeerID}
var wasHolePunch bool
l.transport.holePunchingMx.Lock()
holePunch, ok := l.transport.holePunching[key]
if ok && !holePunch.fulfilled {
holePunch.connCh <- conn
wasHolePunch = true
holePunch.fulfilled = true
}
l.transport.holePunchingMx.Unlock()
if wasHolePunch {
continue
}
return conn, nil return conn, nil
} }
} }
@ -92,6 +107,7 @@ func (l *listener) setupConn(sess quic.Session) (*conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
remoteMultiaddr, err := toQuicMultiaddr(sess.RemoteAddr()) remoteMultiaddr, err := toQuicMultiaddr(sess.RemoteAddr())
if err != nil { if err != nil {
return nil, err return nil, err

110
p2p/transport/quic/transport.go

@ -1,20 +1,23 @@
package libp2pquic package libp2pquic
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"math/rand"
"net" "net"
"sync"
"github.com/libp2p/go-libp2p-core/connmgr" "time"
n "github.com/libp2p/go-libp2p-core/network"
"github.com/minio/sha256-simd" "github.com/minio/sha256-simd"
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
logging "github.com/ipfs/go-log" logging "github.com/ipfs/go-log"
"github.com/libp2p/go-libp2p-core/connmgr"
ic "github.com/libp2p/go-libp2p-core/crypto" ic "github.com/libp2p/go-libp2p-core/crypto"
n "github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/pnet" "github.com/libp2p/go-libp2p-core/pnet"
tpt "github.com/libp2p/go-libp2p-core/transport" tpt "github.com/libp2p/go-libp2p-core/transport"
@ -27,8 +30,12 @@ import (
var log = logging.Logger("quic-transport") var log = logging.Logger("quic-transport")
var ErrHolePunching = errors.New("hole punching attempted; no active dial")
var quicDialContext = quic.DialContext // so we can mock it in tests var quicDialContext = quic.DialContext // so we can mock it in tests
var HolePunchTimeout = 5 * time.Second
var quicConfig = &quic.Config{ var quicConfig = &quic.Config{
MaxIncomingStreams: 1000, MaxIncomingStreams: 1000,
MaxIncomingUniStreams: -1, // disable unidirectional streams MaxIncomingUniStreams: -1, // disable unidirectional streams
@ -96,10 +103,23 @@ type transport struct {
serverConfig *quic.Config serverConfig *quic.Config
clientConfig *quic.Config clientConfig *quic.Config
gater connmgr.ConnectionGater gater connmgr.ConnectionGater
holePunchingMx sync.Mutex
holePunching map[holePunchKey]*activeHolePunch
} }
var _ tpt.Transport = &transport{} var _ tpt.Transport = &transport{}
type holePunchKey struct {
addr string
peer peer.ID
}
type activeHolePunch struct {
connCh chan tpt.CapableConn
fulfilled bool
}
// NewTransport creates a new QUIC transport // NewTransport creates a new QUIC transport
func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater) (tpt.Transport, error) { func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater) (tpt.Transport, error) {
if len(psk) > 0 { if len(psk) > 0 {
@ -138,6 +158,7 @@ func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater) (
serverConfig: config, serverConfig: config,
clientConfig: config.Clone(), clientConfig: config.Clone(),
gater: gater, gater: gater,
holePunching: make(map[holePunchKey]*activeHolePunch),
}, nil }, nil
} }
@ -156,6 +177,13 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
return nil, err return nil, err
} }
tlsConf, keyCh := t.identity.ConfigForPeer(p) tlsConf, keyCh := t.identity.ConfigForPeer(p)
if simConnect, _ := n.GetSimultaneousConnect(ctx); simConnect {
if bytes.Compare([]byte(t.localPeer), []byte(p)) < 0 {
return t.holePunch(ctx, network, addr, p)
}
}
pconn, err := t.connManager.Dial(network, addr) pconn, err := t.connManager.Dial(network, addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -202,6 +230,82 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
return conn, nil return conn, nil
} }
func (t *transport) holePunch(ctx context.Context, network string, addr *net.UDPAddr, p peer.ID) (tpt.CapableConn, error) {
pconn, err := t.connManager.Dial(network, addr)
if err != nil {
return nil, err
}
defer pconn.DecreaseCount()
ctx, cancel := context.WithTimeout(ctx, HolePunchTimeout)
defer cancel()
key := holePunchKey{addr: addr.String(), peer: p}
t.holePunchingMx.Lock()
if _, ok := t.holePunching[key]; ok {
t.holePunchingMx.Unlock()
return nil, fmt.Errorf("already punching hole for %s", addr)
}
connCh := make(chan tpt.CapableConn, 1)
t.holePunching[key] = &activeHolePunch{connCh: connCh}
t.holePunchingMx.Unlock()
var timer *time.Timer
defer func() {
if timer != nil {
timer.Stop()
}
}()
payload := make([]byte, 64)
var punchErr error
loop:
for i := 0; ; i++ {
if _, err := rand.Read(payload); err != nil {
punchErr = err
break
}
if _, err := pconn.UDPConn.WriteToUDP(payload, addr); err != nil {
punchErr = err
break
}
maxSleep := 10 * (i + 1) * (i + 1) // in ms
if maxSleep > 200 {
maxSleep = 200
}
d := 10*time.Millisecond + time.Duration(rand.Intn(maxSleep))*time.Millisecond
if timer == nil {
timer = time.NewTimer(d)
} else {
timer.Reset(d)
}
select {
case c := <-connCh:
t.holePunchingMx.Lock()
delete(t.holePunching, key)
t.holePunchingMx.Unlock()
return c, nil
case <-timer.C:
case <-ctx.Done():
punchErr = ErrHolePunching
break loop
}
}
// we only arrive here if punchErr != nil
t.holePunchingMx.Lock()
defer func() {
delete(t.holePunching, key)
t.holePunchingMx.Unlock()
}()
select {
case c := <-t.holePunching[key].connCh:
return c, nil
default:
return nil, punchErr
}
}
// Don't use mafmt.QUIC as we don't want to dial DNS addresses. Just /ip{4,6}/udp/quic // Don't use mafmt.QUIC as we don't want to dial DNS addresses. Just /ip{4,6}/udp/quic
var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC)) var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC))

2
p2p/transport/quic/transport_test.go

@ -11,7 +11,7 @@ import (
ic "github.com/libp2p/go-libp2p-core/crypto" ic "github.com/libp2p/go-libp2p-core/crypto"
tpt "github.com/libp2p/go-libp2p-core/transport" tpt "github.com/libp2p/go-libp2p-core/transport"
quic "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"

Loading…
Cancel
Save