|
|
@ -1,20 +1,23 @@ |
|
|
|
package libp2pquic |
|
|
|
|
|
|
|
import ( |
|
|
|
"bytes" |
|
|
|
"context" |
|
|
|
"errors" |
|
|
|
"fmt" |
|
|
|
"io" |
|
|
|
"math/rand" |
|
|
|
"net" |
|
|
|
|
|
|
|
"github.com/libp2p/go-libp2p-core/connmgr" |
|
|
|
n "github.com/libp2p/go-libp2p-core/network" |
|
|
|
"sync" |
|
|
|
"time" |
|
|
|
|
|
|
|
"github.com/minio/sha256-simd" |
|
|
|
"golang.org/x/crypto/hkdf" |
|
|
|
|
|
|
|
logging "github.com/ipfs/go-log" |
|
|
|
"github.com/libp2p/go-libp2p-core/connmgr" |
|
|
|
ic "github.com/libp2p/go-libp2p-core/crypto" |
|
|
|
n "github.com/libp2p/go-libp2p-core/network" |
|
|
|
"github.com/libp2p/go-libp2p-core/peer" |
|
|
|
"github.com/libp2p/go-libp2p-core/pnet" |
|
|
|
tpt "github.com/libp2p/go-libp2p-core/transport" |
|
|
@ -27,8 +30,12 @@ import ( |
|
|
|
|
|
|
|
var log = logging.Logger("quic-transport") |
|
|
|
|
|
|
|
var ErrHolePunching = errors.New("hole punching attempted; no active dial") |
|
|
|
|
|
|
|
var quicDialContext = quic.DialContext // so we can mock it in tests
|
|
|
|
|
|
|
|
var HolePunchTimeout = 5 * time.Second |
|
|
|
|
|
|
|
var quicConfig = &quic.Config{ |
|
|
|
MaxIncomingStreams: 1000, |
|
|
|
MaxIncomingUniStreams: -1, // disable unidirectional streams
|
|
|
@ -96,10 +103,23 @@ type transport struct { |
|
|
|
serverConfig *quic.Config |
|
|
|
clientConfig *quic.Config |
|
|
|
gater connmgr.ConnectionGater |
|
|
|
|
|
|
|
holePunchingMx sync.Mutex |
|
|
|
holePunching map[holePunchKey]*activeHolePunch |
|
|
|
} |
|
|
|
|
|
|
|
var _ tpt.Transport = &transport{} |
|
|
|
|
|
|
|
type holePunchKey struct { |
|
|
|
addr string |
|
|
|
peer peer.ID |
|
|
|
} |
|
|
|
|
|
|
|
type activeHolePunch struct { |
|
|
|
connCh chan tpt.CapableConn |
|
|
|
fulfilled bool |
|
|
|
} |
|
|
|
|
|
|
|
// NewTransport creates a new QUIC transport
|
|
|
|
func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater) (tpt.Transport, error) { |
|
|
|
if len(psk) > 0 { |
|
|
@ -138,6 +158,7 @@ func NewTransport(key ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater) ( |
|
|
|
serverConfig: config, |
|
|
|
clientConfig: config.Clone(), |
|
|
|
gater: gater, |
|
|
|
holePunching: make(map[holePunchKey]*activeHolePunch), |
|
|
|
}, nil |
|
|
|
} |
|
|
|
|
|
|
@ -156,6 +177,13 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
tlsConf, keyCh := t.identity.ConfigForPeer(p) |
|
|
|
|
|
|
|
if simConnect, _ := n.GetSimultaneousConnect(ctx); simConnect { |
|
|
|
if bytes.Compare([]byte(t.localPeer), []byte(p)) < 0 { |
|
|
|
return t.holePunch(ctx, network, addr, p) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
pconn, err := t.connManager.Dial(network, addr) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
@ -202,6 +230,82 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp |
|
|
|
return conn, nil |
|
|
|
} |
|
|
|
|
|
|
|
func (t *transport) holePunch(ctx context.Context, network string, addr *net.UDPAddr, p peer.ID) (tpt.CapableConn, error) { |
|
|
|
pconn, err := t.connManager.Dial(network, addr) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
defer pconn.DecreaseCount() |
|
|
|
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, HolePunchTimeout) |
|
|
|
defer cancel() |
|
|
|
|
|
|
|
key := holePunchKey{addr: addr.String(), peer: p} |
|
|
|
t.holePunchingMx.Lock() |
|
|
|
if _, ok := t.holePunching[key]; ok { |
|
|
|
t.holePunchingMx.Unlock() |
|
|
|
return nil, fmt.Errorf("already punching hole for %s", addr) |
|
|
|
} |
|
|
|
connCh := make(chan tpt.CapableConn, 1) |
|
|
|
t.holePunching[key] = &activeHolePunch{connCh: connCh} |
|
|
|
t.holePunchingMx.Unlock() |
|
|
|
|
|
|
|
var timer *time.Timer |
|
|
|
defer func() { |
|
|
|
if timer != nil { |
|
|
|
timer.Stop() |
|
|
|
} |
|
|
|
}() |
|
|
|
|
|
|
|
payload := make([]byte, 64) |
|
|
|
var punchErr error |
|
|
|
loop: |
|
|
|
for i := 0; ; i++ { |
|
|
|
if _, err := rand.Read(payload); err != nil { |
|
|
|
punchErr = err |
|
|
|
break |
|
|
|
} |
|
|
|
if _, err := pconn.UDPConn.WriteToUDP(payload, addr); err != nil { |
|
|
|
punchErr = err |
|
|
|
break |
|
|
|
} |
|
|
|
|
|
|
|
maxSleep := 10 * (i + 1) * (i + 1) // in ms
|
|
|
|
if maxSleep > 200 { |
|
|
|
maxSleep = 200 |
|
|
|
} |
|
|
|
d := 10*time.Millisecond + time.Duration(rand.Intn(maxSleep))*time.Millisecond |
|
|
|
if timer == nil { |
|
|
|
timer = time.NewTimer(d) |
|
|
|
} else { |
|
|
|
timer.Reset(d) |
|
|
|
} |
|
|
|
select { |
|
|
|
case c := <-connCh: |
|
|
|
t.holePunchingMx.Lock() |
|
|
|
delete(t.holePunching, key) |
|
|
|
t.holePunchingMx.Unlock() |
|
|
|
return c, nil |
|
|
|
case <-timer.C: |
|
|
|
case <-ctx.Done(): |
|
|
|
punchErr = ErrHolePunching |
|
|
|
break loop |
|
|
|
} |
|
|
|
} |
|
|
|
// we only arrive here if punchErr != nil
|
|
|
|
t.holePunchingMx.Lock() |
|
|
|
defer func() { |
|
|
|
delete(t.holePunching, key) |
|
|
|
t.holePunchingMx.Unlock() |
|
|
|
}() |
|
|
|
select { |
|
|
|
case c := <-t.holePunching[key].connCh: |
|
|
|
return c, nil |
|
|
|
default: |
|
|
|
return nil, punchErr |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Don't use mafmt.QUIC as we don't want to dial DNS addresses. Just /ip{4,6}/udp/quic
|
|
|
|
var dialMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC)) |
|
|
|
|
|
|
|