// Mirror of https://github.com/libp2p/go-libp2p.git
package holepunch_test

import (
    "context"
    "net"
    "sync"
    "testing"
    "time"

    "github.com/libp2p/go-libp2p"

    "github.com/libp2p/go-libp2p-core/host"
    "github.com/libp2p/go-libp2p-core/network"
    "github.com/libp2p/go-libp2p-core/peer"
    "github.com/libp2p/go-libp2p-core/peerstore"
    "github.com/libp2p/go-libp2p-testing/race"

    relayv1 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv1/relay"
    "github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
    holepunch_pb "github.com/libp2p/go-libp2p/p2p/protocol/holepunch/pb"
    "github.com/libp2p/go-libp2p/p2p/protocol/identify"

    "github.com/libp2p/go-msgio/protoio"
    ma "github.com/multiformats/go-multiaddr"
    manet "github.com/multiformats/go-multiaddr/net"

    "github.com/stretchr/testify/require"
)
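
// mockEventTracer is a thread-safe holepunch.EventTracer that records every
// event it receives so tests can inspect them later.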
type mockEventTracer struct {
    mutex  sync.Mutex
    events []*holepunch.Event
}

func (m *mockEventTracer) Trace(evt *holepunch.Event) {
    m.mutex.Lock()
    m.events = append(m.events, evt)
    m.mutex.Unlock()
}

func (m *mockEventTracer) getEvents() []*holepunch.Event {
    m.mutex.Lock()
    defer m.mutex.Unlock()
    // copy the slice
    return append([]*holepunch.Event{}, m.events...)
}

var _ holepunch.EventTracer = &mockEventTracer{}
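
// TestNoHolePunchIfDirectConnExists checks that DirectConnect is a no-op (and
// emits no events) when the peers already share a direct connection.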
func TestNoHolePunchIfDirectConnExists(t *testing.T) {
    tr := &mockEventTracer{}
    h1, hps := mkHostWithHolePunchSvc(t, holepunch.WithTracer(tr))
    defer h1.Close()
    h2, _ := mkHostWithHolePunchSvc(t)
    defer h2.Close()
    require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{
        ID:    h2.ID(),
        Addrs: h2.Addrs(),
    }))
    time.Sleep(50 * time.Millisecond)
    nc1 := len(h1.Network().ConnsToPeer(h2.ID()))
    require.GreaterOrEqual(t, nc1, 1)
    nc2 := len(h2.Network().ConnsToPeer(h1.ID()))
    require.GreaterOrEqual(t, nc2, 1)

    require.NoError(t, hps.DirectConnect(h2.ID()))
    require.Equal(t, len(h1.Network().ConnsToPeer(h2.ID())), nc1)
    require.Equal(t, len(h2.Network().ConnsToPeer(h1.ID())), nc2)
    require.Empty(t, tr.getEvents())
}
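
// TestDirectDialWorks checks that DirectConnect performs a plain direct dial
// (reported as a DirectDialEvtT event) when the peer's addresses are public.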
func TestDirectDialWorks(t *testing.T) {
    // mark all addresses as public
    cpy := manet.Private4
    manet.Private4 = []*net.IPNet{}
    defer func() { manet.Private4 = cpy }()

    tr := &mockEventTracer{}
    h1, h1ps := mkHostWithHolePunchSvc(t, holepunch.WithTracer(tr))
    defer h1.Close()
    h2, _ := mkHostWithHolePunchSvc(t)
    defer h2.Close()
    h2.RemoveStreamHandler(holepunch.Protocol)
    h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), peerstore.ConnectedAddrTTL)

    // Hole punch with no existing connection or streams; if it succeeds, the
    // resulting connection must be a direct one.
    require.Len(t, h1.Network().ConnsToPeer(h2.ID()), 0)
    require.NoError(t, h1ps.DirectConnect(h2.ID()))
    require.GreaterOrEqual(t, len(h1.Network().ConnsToPeer(h2.ID())), 1)
    require.GreaterOrEqual(t, len(h2.Network().ConnsToPeer(h1.ID())), 1)
    events := tr.getEvents()
    require.Len(t, events, 1)
    require.Equal(t, events[0].Type, holepunch.DirectDialEvtT)
}
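
// TestEndToEndSimConnect runs a full hole punch between two hosts that start
// out connected only via a relay, and asserts on the emitted event sequence.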
func TestEndToEndSimConnect(t *testing.T) {
    tr := &mockEventTracer{}
    h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), true)
    defer h1.Close()
    defer h2.Close()
    defer relay.Close()

    // wait till a direct connection is complete
    ensureDirectConn(t, h1, h2)
    // ensure no hole-punching streams are open on either side
    ensureNoHolePunchingStream(t, h1, h2)
    var events []*holepunch.Event
    require.Eventually(t,
        func() bool {
            events = tr.getEvents()
            return len(events) == 3
        },
        time.Second,
        10*time.Millisecond,
    )
    require.Equal(t, events[0].Type, holepunch.StartHolePunchEvtT)
    require.Equal(t, events[1].Type, holepunch.HolePunchAttemptEvtT)
    require.Equal(t, events[2].Type, holepunch.EndHolePunchEvtT)
}
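
// TestFailuresOnInitiator exercises the initiator-side failure paths: a
// misbehaving, missing, or unresponsive responder must surface an error from
// DirectConnect.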
func TestFailuresOnInitiator(t *testing.T) {
    tcs := map[string]struct {
        rhandler         func(s network.Stream)
        errMsg           string
        holePunchTimeout time.Duration
    }{
        "responder does NOT send a CONNECT message": {
            rhandler: func(s network.Stream) {
                wr := protoio.NewDelimitedWriter(s)
                wr.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_SYNC.Enum()})
            },
            errMsg: "expect CONNECT message, got SYNC",
        },
        "responder does NOT support protocol": {
            rhandler: nil,
        },
        "unable to READ CONNECT message from responder": {
            rhandler: func(s network.Stream) {
                s.Reset()
            },
            errMsg: "failed to read CONNECT message",
        },
        "responder does NOT reply within hole punch deadline": {
            holePunchTimeout: 10 * time.Millisecond,
            rhandler: func(s network.Stream) {
                time.Sleep(5 * time.Second)
            },
            errMsg: "i/o deadline reached",
        },
    }

    for name, tc := range tcs {
        t.Run(name, func(t *testing.T) {
            if tc.holePunchTimeout != 0 {
                cpy := holepunch.StreamTimeout
                holepunch.StreamTimeout = tc.holePunchTimeout
                defer func() { holepunch.StreamTimeout = cpy }()
            }

            tr := &mockEventTracer{}
            h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), false)
            defer h1.Close()
            defer h2.Close()
            defer relay.Close()
            hps := addHolePunchService(t, h2)

            if tc.rhandler != nil {
                h1.SetStreamHandler(holepunch.Protocol, tc.rhandler)
            } else {
                h1.RemoveStreamHandler(holepunch.Protocol)
            }

            err := hps.DirectConnect(h1.ID())
            require.Error(t, err)
            if tc.errMsg != "" {
                require.Contains(t, err.Error(), tc.errMsg)
            }
        })
    }
}
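
// TestFailuresOnResponder exercises the responder-side failure paths: a
// misbehaving or unresponsive initiator must show up as a ProtocolErrorEvt in
// the tracer.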
func TestFailuresOnResponder(t *testing.T) {
    tcs := map[string]struct {
        initiator        func(s network.Stream)
        errMsg           string
        holePunchTimeout time.Duration
    }{
        "initiator does NOT send a CONNECT message": {
            initiator: func(s network.Stream) {
                protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_SYNC.Enum()})
            },
            errMsg: "expected CONNECT message",
        },
        "initiator does NOT send a SYNC message after a Connect message": {
            initiator: func(s network.Stream) {
                w := protoio.NewDelimitedWriter(s)
                w.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()})
                w.WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()})
            },
            errMsg: "expected SYNC message",
        },
        "initiator does NOT reply within hole punch deadline": {
            holePunchTimeout: 10 * time.Millisecond,
            initiator: func(s network.Stream) {
                protoio.NewDelimitedWriter(s).WriteMsg(&holepunch_pb.HolePunch{Type: holepunch_pb.HolePunch_CONNECT.Enum()})
                time.Sleep(10 * time.Second)
            },
            errMsg: "i/o deadline reached",
        },
    }

    for name, tc := range tcs {
        t.Run(name, func(t *testing.T) {
            if tc.holePunchTimeout != 0 {
                cpy := holepunch.StreamTimeout
                holepunch.StreamTimeout = tc.holePunchTimeout
                defer func() { holepunch.StreamTimeout = cpy }()
            }

            tr := &mockEventTracer{}
            h1, h2, relay, _ := makeRelayedHosts(t, holepunch.WithTracer(tr), false)
            defer h1.Close()
            defer h2.Close()
            defer relay.Close()

            s, err := h2.NewStream(context.Background(), h1.ID(), holepunch.Protocol)
            require.NoError(t, err)

            go tc.initiator(s)

            getTracerError := func(tr *mockEventTracer) []string {
                var errs []string
                events := tr.getEvents()
                for _, ev := range events {
                    if errEv, ok := ev.Evt.(*holepunch.ProtocolErrorEvt); ok {
                        errs = append(errs, errEv.Error)
                    }
                }
                return errs
            }

            require.Eventually(t, func() bool { return len(getTracerError(tr)) > 0 }, 5*time.Second, 100*time.Millisecond)
            errs := getTracerError(tr)
            require.Len(t, errs, 1)
            require.Contains(t, errs[0], tc.errMsg)
        })
    }
}
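
// ensureNoHolePunchingStream waits until neither host has a stream speaking
// the hole punch protocol open to the other.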
func ensureNoHolePunchingStream(t *testing.T, h1, h2 host.Host) {
    require.Eventually(t, func() bool {
        for _, c := range h1.Network().ConnsToPeer(h2.ID()) {
            for _, s := range c.GetStreams() {
                if s.Protocol() == holepunch.Protocol {
                    return false
                }
            }
        }
        return true
    }, 5*time.Second, 50*time.Millisecond)

    require.Eventually(t, func() bool {
        for _, c := range h2.Network().ConnsToPeer(h1.ID()) {
            for _, s := range c.GetStreams() {
                if s.Protocol() == holepunch.Protocol {
                    return false
                }
            }
        }
        return true
    }, 5*time.Second, 50*time.Millisecond)
}
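
// ensureDirectConn waits until both hosts have at least one connection to the
// other that does not go through a relay (i.e. has no /p2p-circuit component).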
func ensureDirectConn(t *testing.T, h1, h2 host.Host) {
    require.Eventually(t, func() bool {
        for _, c := range h1.Network().ConnsToPeer(h2.ID()) {
            if _, err := c.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT); err != nil {
                return true
            }
        }
        return false
    }, 5*time.Second, 50*time.Millisecond)

    require.Eventually(t, func() bool {
        for _, c := range h2.Network().ConnsToPeer(h1.ID()) {
            if _, err := c.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT); err != nil {
                return true
            }
        }
        return false
    }, 5*time.Second, 50*time.Millisecond)
}
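
// mkHostWithStaticAutoRelay creates a host that is forced to be privately
// reachable, uses the given host as a static relay, and waits until it has
// obtained a relayed address.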
func mkHostWithStaticAutoRelay(t *testing.T, ctx context.Context, relay host.Host) host.Host {
    if race.WithRace() {
        t.Skip("modifying manet.Private4 is racy")
    }
    pi := peer.AddrInfo{
        ID:    relay.ID(),
        Addrs: relay.Addrs(),
    }

    cpy := manet.Private4
    manet.Private4 = []*net.IPNet{}
    defer func() { manet.Private4 = cpy }()

    h, err := libp2p.New(
        libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0")),
        libp2p.EnableRelay(),
        libp2p.EnableAutoRelay(),
        libp2p.ForceReachabilityPrivate(),
        libp2p.StaticRelays([]peer.AddrInfo{pi}),
    )
    require.NoError(t, err)

    // wait till we have a relay addr
    require.Eventually(t, func() bool {
        for _, a := range h.Addrs() {
            if _, err := a.ValueForProtocol(ma.P_CIRCUIT); err == nil {
                return true
            }
        }
        return false
    }, 5*time.Second, 50*time.Millisecond)
    return h
}
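
// makeRelayedHosts creates two hosts whose only connection to each other runs
// through a circuit-v1 relay; if addHolePuncher is set, h2 also gets a hole
// punch service.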
func makeRelayedHosts(t *testing.T, h1Opt holepunch.Option, addHolePuncher bool) (h1, h2, relay host.Host, hps *holepunch.Service) {
    t.Helper()
    h1, _ = mkHostWithHolePunchSvc(t, h1Opt)
    var err error
    relay, err = libp2p.New(libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0")), libp2p.DisableRelay())
    require.NoError(t, err)

    _, err = relayv1.NewRelay(relay)
    require.NoError(t, err)

    h2 = mkHostWithStaticAutoRelay(t, context.Background(), relay)
    if addHolePuncher {
        hps = addHolePunchService(t, h2)
    }

    // h2 has a relay addr; connect h1 to h2 over that relay addr
    var raddr ma.Multiaddr
    for _, a := range h2.Addrs() {
        if _, err := a.ValueForProtocol(ma.P_CIRCUIT); err == nil {
            raddr = a
            break
        }
    }
    require.NotEmpty(t, raddr)
    require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{
        ID:    h2.ID(),
        Addrs: []ma.Multiaddr{raddr},
    }))
    return
}
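
// addHolePunchService attaches a hole punch service (and the identify service
// it depends on) to the given host.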
func addHolePunchService(t *testing.T, h host.Host) *holepunch.Service {
    t.Helper()
    ids, err := identify.NewIDService(h)
    require.NoError(t, err)
    hps, err := holepunch.NewService(h, ids)
    require.NoError(t, err)
    return hps
}
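
// mkHostWithHolePunchSvc creates a host listening on loopback (IPv4 and IPv6)
// and attaches a hole punch service to it.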
func mkHostWithHolePunchSvc(t *testing.T, opts ...holepunch.Option) (host.Host, *holepunch.Service) {
    t.Helper()
    h, err := libp2p.New(libp2p.ListenAddrs(ma.StringCast("/ip4/127.0.0.1/tcp/0"), ma.StringCast("/ip6/::1/tcp/0")))
    require.NoError(t, err)
    ids, err := identify.NewIDService(h)
    require.NoError(t, err)
    hps, err := holepunch.NewService(h, ids, opts...)
    require.NoError(t, err)
    return h, hps
}